from torch.utils.data import Dataset, DataLoader
from librosa.feature import mfcc
import numpy as np
import librosa
import re
import os
def read_label():
root_path = "/media/dfy/fc0b6513-c379-4548-b391-876575f1493f/home/dfy/Downloads/110_time/"
exist_files=os.listdir(root_path + "qiegezhuan")
with open(root_path + "zong_new.txt", "r", encoding="utf-8") as f:
data = f.readlines()
# 清洗
path_list_wav_dict = {i.split("\t")[1].strip().replace("/noise", ""):root_path + i.split("\t")[0] for i in data }
# 控制语音长度
path_list_wav_dict = {v:"".join(re.compile('[\u4e00-\u9fa5]').findall(k)) for k,v in path_list_wav_dict.items() if len(k)<17}
path_list_wav_dict_ = {"".join(re.compile('[\u4e00-\u9fa5]').findall(v)):k for k,v in path_list_wav_dict.items() if k.split("/")[-1] in exist_files}

r_h_dict=[path_list_wav_dict.get(root_path+"qiegezhuan/"+i,"") for i in exist_files]
chinese_dict={i: v for i, v in enumerate(["<p>", "<e>"] + sorted(set("".join(r_h_dict))))}
chinese_dict_f={v:i for i, v in enumerate(["<p>", "<e>"] + sorted(set("".join(r_h_dict))))}

return [[k,[chinese_dict_f[i] for i in v ]+[1]+[0 for _ in range(31-len(v))]] for v,k in path_list_wav_dict_.items() if len(v)>5]

def padding(data):
    # Pad/truncate the features to a fixed (32, 512) block, filled with -1
    pad = np.ones([32, 512]) * -1
    # numpy slicing clips out-of-range indices, so inputs longer than 512 frames are simply truncated
    pad[:31, :data.shape[-1]] = data[:31, :512]
    return pad
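
# Shape illustration for padding() (comment only, not part of the pipeline; "feats" is a made-up example):
#   feats = np.random.randn(31, 300)   # 31 MFCC rows, 300 frames
#   padded = padding(feats)            # -> (32, 512); columns 300..511 and the last row stay -1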

def getMELspectrogram_mfcc(audio, sample_rate):
    # Load at the target sample rate so it matches the sr passed to the feature extractors
    audio, s = librosa.load(audio, sr=sample_rate)
    # 25 ms windows with a 15 ms hop at 16 kHz (400 / 240 samples)
    librosa_mfcc = mfcc(y=audio, sr=sample_rate, n_mfcc=40, n_mels=31, n_fft=25 * 16, hop_length=15 * 16)
    # Replace the first coefficient with per-frame energy (RMS)
    librosa_mfcc[0] = librosa.feature.rms(y=audio, hop_length=15 * 16, frame_length=25 * 16)
    return librosa_mfcc

class MyDataset(Dataset):
    def __init__(self, name):
        # Load all (wav path, label) pairs; "val" keeps the full list, any other name keeps the first 4000
        self.train_data = read_label()
        if name == "val":
            self.train_data = self.train_data[:]
        else:
            self.train_data = self.train_data[:4000]
        np.random.shuffle(self.train_data)
    def __getitem__(self, index):
        fn, label = self.train_data[index]
        try:
            fn = padding(getMELspectrogram_mfcc(fn, 16000))
        except Exception as e:
            print(fn, e)
            # Delete wav files that are too short to resample, then return an all-pad dummy sample
            if "too small to resample from" in str(e):
                os.remove(fn)
            fn = np.zeros([32, 512])
            label = [0] * 32

        return fn, np.array(label)

    def __len__(self):
        return len(self.train_data)

if __name__ == '__main__':
    dataset = MyDataset("train")
    train_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
    for i, d in enumerate(train_loader):
        print(i)

First, note that the code above builds the spectrogram features at a fixed maximum length, and pads them with -1 rather than the 0 used in the vocabulary. Because the CE (cross-entropy) loss is set to ignore the pad index, the padding has to use a value outside the vocabulary, i.e., one that never appears in the labels, so that the encoded input values remain evenly distributed.
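A minimal sketch of that ignore_index setup (illustrative only, not taken from the training code above; vocab_size and the random tensors are made-up placeholders): with the vocabulary built in read_label(), id 0 is the pad token "<p>" and id 1 is the end token "<e>", so the loss can simply be told to skip padded label positions.

import torch
import torch.nn as nn

vocab_size = 100                                 # hypothetical vocabulary size
criterion = nn.CrossEntropyLoss(ignore_index=0)  # 0 = "<p>" pad id, excluded from the loss

logits = torch.randn(2, vocab_size, 32)          # (batch, classes, sequence length 32)
labels = torch.zeros(2, 32, dtype=torch.long)    # start as all-pad
labels[:, :6] = torch.randint(2, vocab_size, (2, 6))  # a few "real" character ids
labels[:, 6] = 1                                 # "<e>" end marker
loss = criterion(logits, labels)                 # pad positions 7..31 contribute nothing
print(loss.item())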