diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 77e5558..ed7e38f 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,8 +1,12 @@ +from typing import Optional from torch import nn import torch.nn.functional as F import math +from enhancer.models.model import Model +from enhancer.data.dataset import EnhancerDataset from enhancer.utils.io import Audio as audio +from enhancer.utils.utils import merge_dict class DeLSTM(nn.Module): def __init__( @@ -25,64 +29,68 @@ class DeLSTM(nn.Module): return output,(h,c) -class Demucs(nn.Module): +class Demucs(Model): + + ED_DEFAULTS = { + "intial_output_channels":48, + "kernel_size":8, + "stride":1, + "depth":5, + "glu":True, + "growth_factor":2, + } + LSTM_DEFAULTS = { + "bidirectional":True, + "num_layers":2, + } def __init__( self, - c_in:int=1, - c_out:int=1, - hidden:int=48, - kernel_size:int=8, - stride:int=4, - growth_factor:int=2, - depth:int = 5, - glu:bool = True, - bidirectional:bool=True, + encoder_decoder:Optional[dict]=None, + lstm:Optional[dict]=None, + num_channels:int=1, resample:int=4, - sampling_rate = 16000 + sampling_rate = 16000, + dataset:Optional[EnhancerDataset]=None, ): - super().__init__() - self.c_in = c_in - self.c_out = c_out - self.hidden = hidden - self.growth_factor = growth_factor - self.stride = stride - self.kernel_size = kernel_size - self.depth = depth - self.bidirectional = bidirectional - self.activation = nn.GLU(1) if glu else nn.ReLU() - self.resample = resample - self.sampling_rate = sampling_rate - multi_factor = 2 if glu else 1 + super().__init__(num_channels=num_channels, + sampling_rate=sampling_rate,dataset=dataset) + + encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) + lstm = merge_dict(self.LSTM_DEFAULTS,lstm) + self.save_hyperparameters("encoder_decoder","lstm","resample") + + hidden = encoder_decoder["initial_channel_output"] + activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU() + multi_factor = 2 if encoder_decoder["glu"] else 1 self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() - for layer in range(self.depth): + for layer in range(encoder_decoder["depth"]): - encoder_layer = [nn.Conv1d(c_in,hidden,kernel_size,stride), + encoder_layer = [nn.Conv1d(num_channels,hidden,encoder_decoder["kernel_size"],encoder_decoder["stride"]), nn.ReLU(), - nn.Conv1d(hidden, hidden*multi_factor,kernel_size,1), - self.activation] + nn.Conv1d(hidden, hidden*multi_factor,encoder_decoder["kernel_size"],1), + activation] encoder_layer = nn.Sequential(*encoder_layer) self.encoder.append(encoder_layer) - decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,kernel_size,1), - self.activation, - nn.ConvTranspose1d(hidden,c_out,kernel_size,stride) + decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,encoder_decoder["kernel_size"],1), + activation, + nn.ConvTranspose1d(hidden,num_channels,encoder_decoder["kernel_size"],encoder_decoder["stride"]) ] if layer>0: decoder_layer.append(nn.ReLU()) decoder_layer = nn.Sequential(*decoder_layer) self.decoder.insert(0,decoder_layer) - c_out = hidden - c_in = hidden + num_channels = hidden hidden = self.growth_factor * hidden - self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional) + self.de_lstm = DeLSTM(input_size=num_channels,hidden_size=num_channels,num_layers=lstm["num_layers"],bidirectional=lstm["bidirectional"]) def forward(self,mixed_signal): @@ -91,14 +99,13 @@ class Demucs(nn.Module): length = mixed_signal.shape[-1] x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length)) - if self.resample>1: - x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate, - target_sr=int(self.sampling_rate * self.resample)) + if self.hparams.resample>1: + x = audio.pt_resample_audio(audio=x, sr=self.hparams.sampling_rate, + target_sr=int(self.hparams.sampling_rate * self.hparams.resample)) encoder_outputs = [] for encoder in self.encoder: x = encoder(x) - print(x.shape) encoder_outputs.append(x) x = x.permute(0,2,1) x,_ = self.de_lstm(x) @@ -109,23 +116,23 @@ class Demucs(nn.Module): x += skip_connection[..., :x.shape[-1]] x = decoder(x) - if self.resample > 1: - x = audio.pt_resample_audio(x,int(self.sampling_rate * self.resample), - self.sampling_rate) + if self.hparams.resample > 1: + x = audio.pt_resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample), + self.hparams.sampling_rate) return x def get_padding_length(self,input_length): - input_length = math.ceil(input_length * self.resample) + input_length = math.ceil(input_length * self.hparams.resample) - for layer in range(self.depth): # encoder operation + for layer in range(self.hparams.encoder_decoder["depth"]): # encoder operation input_length = math.ceil((input_length - self.kernel_size)/self.stride)+1 input_length = max(1,input_length) - for layer in range(self.depth): # decoder operaration + for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operaration input_length = (input_length-1) * self.stride + self.kernel_size - input_length = math.ceil(input_length/self.resample) + input_length = math.ceil(input_length/self.hparams.resample) return int(input_length)