return hidden & cell states

This commit is contained in:
shahules786 2022-09-07 10:02:56 +05:30
parent fb6d0afee0
commit 4edc90deb0
1 changed files with 7 additions and 4 deletions

View File

@ -23,7 +23,7 @@ class DeLSTM(nn.Module):
output,(h,c) = self.lstm(x) output,(h,c) = self.lstm(x)
output = self.linear(output) output = self.linear(output)
return output return output,(h,c)
class Demucs(nn.Module): class Demucs(nn.Module):
@ -85,20 +85,23 @@ class Demucs(nn.Module):
self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional) self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
def forward(self,mixed_signal): def forward(self,mixed_signal):
if mixed_signal.dim() == 2:
mixed_signal = mixed_signal.unsqueeze(1)
length = mixed_signal.shape[-1] length = mixed_signal.shape[-1]
x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length)) x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
if self.resample>1: if self.resample>1:
x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate, x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate,
target_sr=int(self.sampling_rate * self.resample)) target_sr=int(self.sampling_rate * self.resample))
print("resampled->",x.shape)
encoder_outputs = [] encoder_outputs = []
for encoder in self.encoder: for encoder in self.encoder:
x = encoder(x) x = encoder(x)
print(x.shape) print(x.shape)
encoder_outputs.append(x) encoder_outputs.append(x)
x = x.permute(0,2,1) x = x.permute(0,2,1)
x = self.de_lstm(x) x,_ = self.de_lstm(x)
x = x.permute(0,2,1) x = x.permute(0,2,1)
for decoder in self.decoder: for decoder in self.decoder: