From 409afc31fc6b0b22f26a046eb5ce6d58f0d83dc6 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 5 Sep 2022 17:12:03 +0530 Subject: [PATCH] demucs forward --- enhancer/models/demucs.py | 55 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 7ec9989..618e9ab 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,5 +1,9 @@ from typing import bool -from torch import nn +from torch import nn +import torch.functional as F +import math + +from enhancer.utils.io import Audio as audio class DeLSTM(nn.Module): def __init__( @@ -35,8 +39,10 @@ class Demus(nn.Module): glu:bool = True, bidirectional:bool=True, resample:int=2, + sampling_rate = 16000 ): + super().__init__() self.c_in = c_in self.c_out = c_out self.hidden = hidden @@ -46,10 +52,10 @@ class Demus(nn.Module): self.depth = depth self.bidirectional = bidirectional self.activation = nn.GLU(1) if glu else nn.ReLU() + self.resample = resample + self.sampling_rate = sampling_rate multi_factor = 2 if glu else 1 - ## do resampling - self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() @@ -78,7 +84,48 @@ class Demus(nn.Module): self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional) - def forward(self,input): + def forward(self,mixed_signal): + + length = mixed_signal.shape[-1] + x = F.pad((0,self.get_padding_length(length) - length)) + if self.resample>1: + x = audio.resample_audio(audio=x, + sampling_rate = int(self.sampling_rate * self.resample)) + + encoder_outputs = [] + for encoder in self.encoder: + x = encoder(x) + encoder_outputs.append(x) + + x,_ = self.de_lstm(x) + + for decoder in self.decoder: + skip_connection = encoder_outputs.pop(-1) + x += skip_connection[..., :x.shape[-1]] + x = decoder(x) + + if self.resample > 1: + x = audio.resample_audio(x,int(self.sampling_rate * self.resample), + self.sampling_rate) + + return x + + def get_padding_length(self,input_length): + + input_length = math.ceil(input_length * self.resample) + + + for layer in range(self.depth): # encoder operation + input_length = math.ceil((input_length - self.kernel_size)/self.stride)+1 + input_length = max(1,input_length) + for layer in range(self.depth): # decoder operaration + input_length = (input_length-1) * self.stride + self.kernel_size + input_length = math.ceil(input_length/self.resample) + + return int(input_length) + + +