diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index b18a7af..77e5558 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -23,7 +23,7 @@ class DeLSTM(nn.Module): output,(h,c) = self.lstm(x) output = self.linear(output) - return output + return output,(h,c) class Demucs(nn.Module): @@ -85,20 +85,23 @@ class Demucs(nn.Module): self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional) def forward(self,mixed_signal): - + + if mixed_signal.dim() == 2: + mixed_signal = mixed_signal.unsqueeze(1) + length = mixed_signal.shape[-1] x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length)) if self.resample>1: x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate, target_sr=int(self.sampling_rate * self.resample)) - print("resampled->",x.shape) + encoder_outputs = [] for encoder in self.encoder: x = encoder(x) print(x.shape) encoder_outputs.append(x) x = x.permute(0,2,1) - x = self.de_lstm(x) + x,_ = self.de_lstm(x) x = x.permute(0,2,1) for decoder in self.decoder: