import torch
class Speech_Detect(torch.nn.Module):
    """Project a variable-length leading axis to a fixed width and back out.

    For every supported length ``L`` in ``range(min_l, max_l)`` two layers are
    built: an input projection ``Linear(L, x_l)`` and an output projection
    ``Linear(x_l, L)``. Note that ``max_l`` itself is *excluded* by ``range``.

    NOTE(review): the name ``max_l`` suggests an inclusive bound; confirm
    whether length ``max_l`` was meant to be supported.

    Args:
        min_l: Smallest supported length (inclusive).
        max_l: Upper bound on supported lengths (exclusive).
        x_l: Width of the intermediate representation.
    """

    def __init__(self, min_l=3, max_l=20, x_l=10):
        super().__init__()
        # Layers must live in ModuleDicts so nn.Module registers them: only then
        # do they appear in parameters()/state_dict() and follow .to()/.cuda()/
        # .double(). Layers held in a plain dict are invisible to all of those.
        # ModuleDict keys must be strings, hence str(length).
        self.in_layers = torch.nn.ModuleDict()
        self.out_layers = torch.nn.ModuleDict()
        for length in range(min_l, max_l):
            self.in_layers[str(length)] = torch.nn.Linear(length, x_l)
            self.out_layers[str(length)] = torch.nn.Linear(x_l, length)
        # Int-keyed views onto the *same* layer objects, so existing code that
        # indexes ``in_dict[length]`` / ``out_dict[length]`` keeps working.
        # Plain dicts are not registered by nn.Module, so parameters are not
        # double-counted; device/dtype moves update the shared layers in place.
        self.in_dict = {int(key): layer for key, layer in self.in_layers.items()}
        self.out_dict = {int(key): layer for key, layer in self.out_layers.items()}

    def forward(self, x, label_len):
        """Resize the leading axis of ``x`` from ``x.shape[0]`` to ``label_len``.

        Assumes ``x`` is 2-D with shape ``(in_len, features)`` (TODO confirm
        with callers). ``x.T`` puts ``in_len`` last so the ``Linear`` layers act
        on it. The result has shape ``(label_len, features)``.

        Args:
            x: Input tensor; its first dimension selects the input projection.
            label_len: Desired length of the output's first dimension.

        Returns:
            Tensor of shape ``(label_len, features)``.

        Raises:
            KeyError: If ``x.shape[0]`` or ``label_len`` is not a supported
                length.
        """
        in_len = x.shape[0]
        try:
            in_proj = self.in_dict[in_len]
        except KeyError:
            raise KeyError(
                f"unsupported input length {in_len}; "
                f"supported lengths: {sorted(self.in_dict)}"
            ) from None
        try:
            out_proj = self.out_dict[label_len]
        except KeyError:
            raise KeyError(
                f"unsupported label_len {label_len}; "
                f"supported lengths: {sorted(self.out_dict)}"
            ) from None
        hidden = in_proj(x.T)
        return out_proj(hidden).T


if __name__ == '__main__':
    # Demo: push 12x12 random inputs through the model, asking for 11 output
    # rows, and print the resulting shape. It lives under the main guard so
    # importing this module does not build a model and print 100 lines.
    net = Speech_Detect()

    for _ in range(100):
        data = torch.randn([12, 12])
        out = net(data, 11)
        print(out.shape)