From b42ca28851f172797b28d83dc6f9c26e34b2d4a1 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 6 Sep 2022 20:44:19 +0530 Subject: [PATCH] fix shapes --- enhancer/models/demucs.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 618e9ab..b18a7af 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,6 +1,5 @@ -from typing import bool from torch import nn -import torch.functional as F +import torch.nn.functional as F import math from enhancer.utils.io import Audio as audio @@ -14,6 +13,7 @@ class DeLSTM(nn.Module): bidirectional:bool=True ): + super().__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional) dim = 2 if bidirectional else 1 self.linear = nn.Linear(dim*hidden_size,hidden_size) @@ -25,7 +25,7 @@ class DeLSTM(nn.Module): return output -class Demus(nn.Module): +class Demucs(nn.Module): def __init__( self, @@ -35,10 +35,10 @@ class Demus(nn.Module): kernel_size:int=8, stride:int=4, growth_factor:int=2, - depth:int = 6, + depth:int = 5, glu:bool = True, bidirectional:bool=True, - resample:int=2, + resample:int=4, sampling_rate = 16000 ): @@ -65,8 +65,8 @@ class Demus(nn.Module): nn.ReLU(), nn.Conv1d(hidden, hidden*multi_factor,kernel_size,1), self.activation] - encoder_layer = nn.Sequential(encoder_layer) - self.encoder.append(*encoder_layer) + encoder_layer = nn.Sequential(*encoder_layer) + self.encoder.append(encoder_layer) decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,kernel_size,1), self.activation, @@ -87,25 +87,27 @@ class Demus(nn.Module): def forward(self,mixed_signal): length = mixed_signal.shape[-1] - x = F.pad((0,self.get_padding_length(length) - length)) + x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length)) if self.resample>1: - x = audio.resample_audio(audio=x, - sampling_rate = int(self.sampling_rate * self.resample)) - + x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate, + target_sr=int(self.sampling_rate * self.resample)) + print("resampled->",x.shape) encoder_outputs = [] for encoder in self.encoder: x = encoder(x) + print(x.shape) encoder_outputs.append(x) - - x,_ = self.de_lstm(x) + x = x.permute(0,2,1) + x = self.de_lstm(x) + x = x.permute(0,2,1) for decoder in self.decoder: skip_connection = encoder_outputs.pop(-1) x += skip_connection[..., :x.shape[-1]] x = decoder(x) if self.resample > 1: - x = audio.resample_audio(x,int(self.sampling_rate * self.resample), + x = audio.pt_resample_audio(x,int(self.sampling_rate * self.resample), self.sampling_rate) return x