from typing import Optional, Union, List
import math

from torch import nn
import torch.nn.functional as F

from enhancer.models.model import Model
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict


class DemucsLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        super().__init__()
        # batch_first=True because Demucs.forward feeds (batch, time, channels)
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            bidirectional=bidirectional,
            batch_first=True,
        )
        dim = 2 if bidirectional else 1
        self.linear = nn.Linear(dim * hidden_size, hidden_size)

    def forward(self, x):
        output, (h, c) = self.lstm(x)
        output = self.linear(output)
        return output, (h, c)


class DemucsEncoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.encoder = nn.Sequential(
            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
            nn.ReLU(),
            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
            activation,  # GLU halves the channels back to hidden_size when glu=True
        )

    def forward(self, waveform):
        return self.encoder(waveform)


class DemucsDecoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
        layer: int = 0,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.decoder = nn.Sequential(
            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
            activation,
            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
        )
        # every decoder except the final output decoder (layer 0) ends with a ReLU
        if layer > 0:
            self.decoder.add_module("4", nn.ReLU())

    def forward(self, waveform):
        return self.decoder(waveform)


class Demucs(Model):

    ED_DEFAULTS = {
        "initial_output_channels": 48,
        "kernel_size": 8,
        "stride": 1,
        "depth": 5,
        "glu": True,
        "growth_factor": 2,
    }
    LSTM_DEFAULTS = {
        "bidirectional": True,
        "num_layers": 2,
    }

    def __init__(
        self,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        num_channels: int = 1,
        resample: int = 4,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = dataset.duration if isinstance(dataset, EnhancerDataset) else None
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")

        hidden = encoder_decoder["initial_output_channels"]
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for layer in range(encoder_decoder["depth"]):
            encoder_layer = DemucsEncoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
                glu=encoder_decoder["glu"],
            )
            self.encoder.append(encoder_layer)

            # decoders are inserted at the front, so self.decoder[0] mirrors the
            # deepest encoder; the stride matches the encoder so that the length
            # arithmetic in get_padding_length holds
            decoder_layer = DemucsDecoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
                glu=encoder_decoder["glu"],
                layer=layer,
            )
            self.decoder.insert(0, decoder_layer)

            num_channels = hidden
            hidden = encoder_decoder["growth_factor"] * hidden

        self.de_lstm = DemucsLSTM(
            input_size=num_channels,
            hidden_size=num_channels,
            num_layers=lstm["num_layers"],
            bidirectional=lstm["bidirectional"],
        )

    def forward(self, waveform):
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)
        if waveform.size(1) != 1:
            raise TypeError(
                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
            )
        length = waveform.shape[-1]
        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
        if self.hparams.resample > 1:
            x = audio.resample_audio(
                audio=x,
                sr=self.hparams.sampling_rate,
                target_sr=int(self.hparams.sampling_rate * self.hparams.resample),
            )

        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)

        # LSTM expects (batch, time, channels)
        x = x.permute(0, 2, 1)
        x, _ = self.de_lstm(x)
        x = x.permute(0, 2, 1)

        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            x = x + skip_connection[..., : x.shape[-1]]
            x = decoder(x)

        if self.hparams.resample > 1:
            x = audio.resample_audio(
                x,
                int(self.hparams.sampling_rate * self.hparams.resample),
                self.hparams.sampling_rate,
            )
        return x

    def get_padding_length(self, input_length):
        kernel_size = self.hparams.encoder_decoder["kernel_size"]
        stride = self.hparams.encoder_decoder["stride"]
        depth = self.hparams.encoder_decoder["depth"]

        input_length = math.ceil(input_length * self.hparams.resample)
        for _ in range(depth):
            # encoder operation: Conv1d output length
            input_length = math.ceil((input_length - kernel_size) / stride) + 1
            input_length = max(1, input_length)
        for _ in range(depth):
            # decoder operation: ConvTranspose1d output length
            input_length = (input_length - 1) * stride + kernel_size
        input_length = math.ceil(input_length / self.hparams.resample)
        return int(input_length)
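
# Minimal usage sketch (not part of the original module). It assumes the Model
# base class tolerates dataset=None (the default in the signature above) and is
# only meant to illustrate the expected input/output shapes; training would
# normally go through an EnhancerDataset and a Lightning trainer instead of
# calling forward directly.
if __name__ == "__main__":
    import torch

    model = Demucs(num_channels=1, sampling_rate=16000)
    model.eval()

    with torch.no_grad():
        # one second of random 16 kHz mono audio: (batch, channels, samples)
        noisy = torch.randn(1, 1, 16000)
        enhanced = model(noisy)

    print("input:", tuple(noisy.shape), "output:", tuple(enhanced.shape))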