import torch
class Speech_Detect(torch.nn.Module):
    """Map a length-``L`` sequence to a length-``label_len`` sequence.

    For every supported length ``i`` in ``range(min_l, max_l)`` the model owns
    an input projection ``Linear(i, x_l)`` and an output projection
    ``Linear(x_l, i)``. ``forward`` selects the input projection by the size of
    the input's first dimension and the output projection by ``label_len``.

    Args:
        min_l: Smallest supported length (inclusive).
        max_l: Upper bound on supported lengths (exclusive, like ``range``).
        x_l: Width of the hidden dimension shared by all projections.
    """

    def __init__(self, min_l=3, max_l=20, x_l=10):
        super().__init__()
        # ModuleDict rather than a plain dict: only registered submodules show
        # up in parameters()/state_dict() and follow .to()/.train()/.eval().
        # ModuleDict keys must be strings, so lengths are stored as str(i).
        self.in_dict = torch.nn.ModuleDict()
        self.out_dict = torch.nn.ModuleDict()
        # Creation order (in, out, in, out, ...) matches the original, so the
        # same RNG seed yields the same initial weights.
        for i in range(min_l, max_l):
            self.in_dict[str(i)] = torch.nn.Linear(i, x_l)
            self.out_dict[str(i)] = torch.nn.Linear(x_l, i)

    def forward(self, x, label_len):
        """Project ``x`` from its own length to ``label_len``.

        Args:
            x: Tensor whose first dimension is the input length. It is
                transposed with ``.T`` so that this dimension becomes the last
                one (the one ``Linear`` acts on); 1-D or 2-D inputs are assumed.
            label_len: Desired output length (int-like).

        Returns:
            Tensor of shape ``(label_len, ...)``. For a 2-D ``(L, F)`` input the
            result is ``(label_len, F)``.

        Raises:
            KeyError: if the input length or ``label_len`` has no projection.
        """
        in_key = str(x.shape[0])
        out_key = str(int(label_len))
        if in_key not in self.in_dict:
            raise KeyError(
                f"unsupported input length {x.shape[0]}; "
                f"supported: {sorted(int(k) for k in self.in_dict)}"
            )
        if out_key not in self.out_dict:
            raise KeyError(
                f"unsupported label_len {label_len}; "
                f"supported: {sorted(int(k) for k in self.out_dict)}"
            )
        hidden = self.in_dict[in_key](x.T)
        out = self.out_dict[out_key](hidden)
        return out.T
def main(num_iters=100):
    """Smoke-test Speech_Detect by mapping random 12x12 inputs to length 11.

    Builds a default model, then ``num_iters`` times feeds a fresh random
    ``(12, 12)`` tensor with ``label_len=11`` and prints the output shape.
    """
    net = Speech_Detect()
    # Only the shapes are inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        for _ in range(num_iters):
            data = torch.randn([12, 12])
            out = net(data, 11)
            print(out.shape)


# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()