From 3f40b54fc693d2311d0b1b531a12e7f22088bda6 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 21 Sep 2022 10:36:56 +0530
Subject: [PATCH] refactor encoder-decoder

---
 enhancer/models/demucs.py | 95 ++++++++++++++++++++++++++++++++-------
 1 file changed, 76 insertions(+), 19 deletions(-)

diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
index 7e51c87..571a915 100644
--- a/enhancer/models/demucs.py
+++ b/enhancer/models/demucs.py
@@ -8,7 +8,7 @@ from enhancer.data.dataset import EnhancerDataset
 from enhancer.utils.io import Audio as audio
 from enhancer.utils.utils import merge_dict
 
-class DeLSTM(nn.Module):
+class DemucsLSTM(nn.Module):
     def __init__(
         self,
         input_size:int,
@@ -29,6 +29,61 @@ class DeLSTM(nn.Module):
 
         return output,(h,c)
 
+
+class DemucsEncoder(nn.Module):
+    # One encoder block: strided Conv1d downsampling, then Conv1d + GLU/ReLU.
+
+    def __init__(
+        self,
+        num_channels:int,
+        hidden_size:int,
+        kernel_size:int,
+        stride:int=1,
+        glu:bool=False,
+    ):
+        super().__init__()
+        activation = nn.GLU(1) if glu else nn.ReLU()
+        multi_factor = 2 if glu else 1
+        self.encoder = nn.Sequential(
+            nn.Conv1d(num_channels,hidden_size,kernel_size,stride),
+            nn.ReLU(),
+            nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1),
+            activation
+        )
+
+    def forward(self,waveform):
+
+        return self.encoder(waveform)
+
+class DemucsDecoder(nn.Module):
+    # One decoder block: Conv1d + GLU/ReLU, then ConvTranspose1d upsampling.
+
+    def __init__(
+        self,
+        num_channels:int,
+        hidden_size:int,
+        kernel_size:int,
+        stride:int=1,
+        glu:bool=False,
+        layer:int=0
+    ):
+        super().__init__()
+        activation = nn.GLU(1) if glu else nn.ReLU()
+        multi_factor = 2 if glu else 1
+        self.decoder = nn.Sequential(
+            nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1),
+            activation,
+            nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
+        )
+        if layer>0:
+            self.decoder.add_module("3", nn.ReLU())  # named "3", after modules "0"-"2"
+
+    def forward(self,waveform):
+
+        out = self.decoder(waveform)
+        return out
+
+
 class Demucs(Model):
 
     ED_DEFAULTS = {
@@ -56,39 +111,41 @@ class Demucs(Model):
         loss:Union[str, List] = "mse"
     ):
 
+        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
         super().__init__(num_channels=num_channels,
                          sampling_rate=sampling_rate,lr=lr,
-                        dataset=dataset,loss=loss)
+                        dataset=dataset,duration=duration,loss=loss)
         encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
         lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
         self.save_hyperparameters("encoder_decoder","lstm","resample")
         hidden = encoder_decoder["initial_output_channels"]
-        activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU()
-        multi_factor = 2 if encoder_decoder["glu"] else 1
-
         self.encoder = nn.ModuleList()
         self.decoder = nn.ModuleList()
 
         for layer in range(encoder_decoder["depth"]):
-            encoder_layer = [nn.Conv1d(num_channels,hidden,encoder_decoder["kernel_size"],encoder_decoder["stride"]),
-                             nn.ReLU(),
-                             nn.Conv1d(hidden, hidden*multi_factor,encoder_decoder["kernel_size"],1),
-                             activation]
-            encoder_layer = nn.Sequential(*encoder_layer)
+            encoder_layer = DemucsEncoder(num_channels=num_channels,
+                                          hidden_size=hidden,
+                                          kernel_size=encoder_decoder["kernel_size"],
+                                          stride=encoder_decoder["stride"],
+                                          glu=encoder_decoder["glu"],
+                                          )
             self.encoder.append(encoder_layer)
 
-            decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,encoder_decoder["kernel_size"],1),
-                             activation,
-                             nn.ConvTranspose1d(hidden,num_channels,encoder_decoder["kernel_size"],encoder_decoder["stride"])
-                             ]
-            if layer>0:
-                decoder_layer.append(nn.ReLU())
-            decoder_layer = nn.Sequential(*decoder_layer)
+            decoder_layer = DemucsDecoder(num_channels=num_channels,
+                                          hidden_size=hidden,
+                                          kernel_size=encoder_decoder["kernel_size"],
+                                          stride=encoder_decoder["stride"],
+                                          glu=encoder_decoder["glu"],
+                                          layer=layer
+                                          )
             self.decoder.insert(0,decoder_layer)
 
             num_channels = hidden
             hidden = self.ED_DEFAULTS["growth_factor"] * hidden
-
-        self.de_lstm = DeLSTM(input_size=num_channels,hidden_size=num_channels,num_layers=lstm["num_layers"],bidirectional=lstm["bidirectional"])
+        self.de_lstm = DemucsLSTM(input_size=num_channels,
+                                  hidden_size=num_channels,
+                                  num_layers=lstm["num_layers"],
+                                  bidirectional=lstm["bidirectional"]
+                                  )
 
     def forward(self,mixed_signal):
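
Reviewer note (not part of the patch): a minimal shape smoke test for the two new blocks, assuming the DemucsEncoder/DemucsDecoder definitions above. The values hidden_size=48, kernel_size=8 and stride=4 are illustrative stand-ins; the real defaults live in ED_DEFAULTS, which this diff does not show.

    import torch

    from enhancer.models.demucs import DemucsDecoder, DemucsEncoder

    # Illustrative sizes only; real values come from ED_DEFAULTS / user config.
    encoder = DemucsEncoder(num_channels=1, hidden_size=48, kernel_size=8,
                            stride=4, glu=True)
    decoder = DemucsDecoder(num_channels=1, hidden_size=48, kernel_size=8,
                            stride=4, glu=True, layer=0)

    x = torch.randn(2, 1, 16000)  # (batch, channels, samples)
    z = encoder(x)                # strided Conv1d shrinks the time axis
    print(z.shape)                # torch.Size([2, 48, 3992])
    y = decoder(z)                # ConvTranspose1d stretches it back
    print(y.shape)                # torch.Size([2, 1, 15944])

With glu=True the inner Conv1d doubles its output channels (multi_factor=2) and nn.GLU(1) halves them again, so each block emits hidden_size channels either way. The decoder output comes back slightly shorter than the input because the convolutions are unpadded; the full Demucs model has to pad or trim to compensate.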