diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
index 7c9d8ff..76a0bf7 100644
--- a/enhancer/models/demucs.py
+++ b/enhancer/models/demucs.py
@@ -9,209 +9,255 @@ from enhancer.data.dataset import EnhancerDataset
 from enhancer.utils.io import Audio as audio
 from enhancer.utils.utils import merge_dict
 
+
 class DemucsLSTM(nn.Module):
     def __init__(
         self,
-        input_size:int,
-        hidden_size:int,
-        num_layers:int,
-        bidirectional:bool=True
-
+        input_size: int,
+        hidden_size: int,
+        num_layers: int,
+        bidirectional: bool = True,
     ):
         super().__init__()
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
+        self.lstm = nn.LSTM(
+            input_size, hidden_size, num_layers, bidirectional=bidirectional
+        )
         dim = 2 if bidirectional else 1
-        self.linear = nn.Linear(dim*hidden_size,hidden_size)
+        self.linear = nn.Linear(dim * hidden_size, hidden_size)
 
-    def forward(self,x):
+    def forward(self, x):
 
-        output,(h,c) = self.lstm(x)
+        output, (h, c) = self.lstm(x)
         output = self.linear(output)
 
-        return output,(h,c)
+        return output, (h, c)
 
 
 class DemucsEncoder(nn.Module):
-
     def __init__(
         self,
-        num_channels:int,
-        hidden_size:int,
-        kernel_size:int,
-        stride:int=1,
-        glu:bool=False,
+        num_channels: int,
+        hidden_size: int,
+        kernel_size: int,
+        stride: int = 1,
+        glu: bool = False,
    ):
         super().__init__()
         activation = nn.GLU(1) if glu else nn.ReLU()
         multi_factor = 2 if glu else 1
         self.encoder = nn.Sequential(
-            nn.Conv1d(num_channels,hidden_size,kernel_size,stride),
+            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
             nn.ReLU(),
-            nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1),
-            activation
+            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
+            activation,
         )
 
-    def forward(self,waveform):
-
+    def forward(self, waveform):
+
         return self.encoder(waveform)
 
 
-class DemucsDecoder(nn.Module):
+class DemucsDecoder(nn.Module):
     def __init__(
         self,
-        num_channels:int,
-        hidden_size:int,
-        kernel_size:int,
-        stride:int=1,
-        glu:bool=False,
-        layer:int=0
+        num_channels: int,
+        hidden_size: int,
+        kernel_size: int,
+        stride: int = 1,
+        glu: bool = False,
+        layer: int = 0,
     ):
         super().__init__()
         activation = nn.GLU(1) if glu else nn.ReLU()
         multi_factor = 2 if glu else 1
         self.decoder = nn.Sequential(
-            nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1),
+            nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
             activation,
-            nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
+            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
         )
-        if layer>0:
+        if layer > 0:
             self.decoder.add_module("4", nn.ReLU())
 
-    def forward(self,waveform,):
+    def forward(
+        self,
+        waveform,
+    ):
         out = self.decoder(waveform)
         return out
 
 
 class Demucs(Model):
+    """
+    Demucs model from https://arxiv.org/pdf/1911.13254.pdf
+
+    parameters:
+        encoder_decoder : dict, optional
+            keyword arguments passed to the encoder-decoder blocks
+        lstm : dict, optional
+            keyword arguments passed to the LSTM block
+        num_channels : int, defaults to 1
+            number of channels in input audio
+        sampling_rate : int, defaults to 16000 (16 kHz)
+            sampling rate of input audio
+        lr : float, defaults to 1e-3
+            learning rate used for training
+        dataset : EnhancerDataset, optional
+            EnhancerDataset object containing train/validation data for training
+        duration : float, optional
+            chunk duration in seconds (inferred from the dataset when provided)
+        loss : string or List of strings
+            loss function to be used, available ("mse","mae","SI-SDR")
+        metric : string or List of strings
+            metric function to be used, available ("mse","mae","SI-SDR")
+    """
 
     ED_DEFAULTS = {
-        "initial_output_channels":48,
-        "kernel_size":8,
-        "stride":1,
-        "depth":5,
-        "glu":True,
-        "growth_factor":2,
+        "initial_output_channels": 48,
+        "kernel_size": 8,
+        "stride": 1,
+        "depth": 5,
+        "glu": True,
+        "growth_factor": 2,
     }
     LSTM_DEFAULTS = {
-        "bidirectional":True,
-        "num_layers":2,
+        "bidirectional": True,
+        "num_layers": 2,
     }
-
+
     def __init__(
         self,
-        encoder_decoder:Optional[dict]=None,
-        lstm:Optional[dict]=None,
-        num_channels:int=1,
-        resample:int=4,
-        sampling_rate = 16000,
-        lr:float=1e-3,
-        dataset:Optional[EnhancerDataset]=None,
-        loss:Union[str, List] = "mse",
-        metric:Union[str, List] = "mse"
-
-
+        encoder_decoder: Optional[dict] = None,
+        lstm: Optional[dict] = None,
+        num_channels: int = 1,
+        resample: int = 4,
+        sampling_rate=16000,
+        lr: float = 1e-3,
+        dataset: Optional[EnhancerDataset] = None,
+        loss: Union[str, List] = "mse",
+        metric: Union[str, List] = "mse",
     ):
 
-        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        duration = (
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
+        )
         if dataset is not None:
-            if sampling_rate!=dataset.sampling_rate:
-                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            if sampling_rate != dataset.sampling_rate:
+                logging.warn(
+                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+                )
                 sampling_rate = dataset.sampling_rate
-        super().__init__(num_channels=num_channels,
-                        sampling_rate=sampling_rate,lr=lr,
-                        dataset=dataset,duration=duration,loss=loss, metric=metric)
-
-        encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
-        lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
-        self.save_hyperparameters("encoder_decoder","lstm","resample")
+        super().__init__(
+            num_channels=num_channels,
+            sampling_rate=sampling_rate,
+            lr=lr,
+            dataset=dataset,
+            duration=duration,
+            loss=loss,
+            metric=metric,
+        )
+
+        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
+        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
         hidden = encoder_decoder["initial_output_channels"]
         self.encoder = nn.ModuleList()
         self.decoder = nn.ModuleList()
         for layer in range(encoder_decoder["depth"]):
-            encoder_layer = DemucsEncoder(num_channels=num_channels,
-                            hidden_size=hidden,
-                            kernel_size=encoder_decoder["kernel_size"],
-                            stride=encoder_decoder["stride"],
-                            glu=encoder_decoder["glu"],
-                            )
+            encoder_layer = DemucsEncoder(
+                num_channels=num_channels,
+                hidden_size=hidden,
+                kernel_size=encoder_decoder["kernel_size"],
+                stride=encoder_decoder["stride"],
+                glu=encoder_decoder["glu"],
+            )
 
             self.encoder.append(encoder_layer)
-            decoder_layer = DemucsDecoder(num_channels=num_channels,
-                            hidden_size=hidden,
-                            kernel_size=encoder_decoder["kernel_size"],
-                            stride=1,
-                            glu=encoder_decoder["glu"],
-                            layer=layer
-                            )
-            self.decoder.insert(0,decoder_layer)
+            decoder_layer = DemucsDecoder(
+                num_channels=num_channels,
+                hidden_size=hidden,
+                kernel_size=encoder_decoder["kernel_size"],
+                stride=1,
+                glu=encoder_decoder["glu"],
+                layer=layer,
+            )
+            self.decoder.insert(0, decoder_layer)
 
             num_channels = hidden
             hidden = self.ED_DEFAULTS["growth_factor"] * hidden
-
-        self.de_lstm = DemucsLSTM(input_size=num_channels,
-                        hidden_size=num_channels,
-                        num_layers=lstm["num_layers"],
-                        bidirectional=lstm["bidirectional"]
-                        )
-    def forward(self,waveform):
+        self.de_lstm = DemucsLSTM(
+            input_size=num_channels,
+            hidden_size=num_channels,
+            num_layers=lstm["num_layers"],
+            bidirectional=lstm["bidirectional"],
+        )
+
+    def forward(self, waveform):
 
         if waveform.dim() == 2:
             waveform = waveform.unsqueeze(1)
 
-        if waveform.size(1)!=1:
-            raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
+        if waveform.size(1) != 1:
+            raise TypeError(
+                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
+            )
 
         length = waveform.shape[-1]
-        x = F.pad(waveform, (0,self.get_padding_length(length) - length))
-        if self.hparams.resample>1:
-            x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
-                    target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
-
+        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
+        if self.hparams.resample > 1:
+            x = audio.resample_audio(
+                audio=x,
+                sr=self.hparams.sampling_rate,
+                target_sr=int(
+                    self.hparams.sampling_rate * self.hparams.resample
+                ),
+            )
+
         encoder_outputs = []
         for encoder in self.encoder:
             x = encoder(x)
             encoder_outputs.append(x)
 
-        x = x.permute(0,2,1)
-        x,_ = self.de_lstm(x)
+        x = x.permute(0, 2, 1)
+        x, _ = self.de_lstm(x)
 
-        x = x.permute(0,2,1)
+        x = x.permute(0, 2, 1)
         for decoder in self.decoder:
             skip_connection = encoder_outputs.pop(-1)
-            x += skip_connection[..., :x.shape[-1]]
+            x += skip_connection[..., : x.shape[-1]]
             x = decoder(x)
-
+
         if self.hparams.resample > 1:
-            x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
-                                    self.hparams.sampling_rate)
+            x = audio.resample_audio(
+                x,
+                int(self.hparams.sampling_rate * self.hparams.resample),
+                self.hparams.sampling_rate,
+            )
 
         return x
-
-    def get_padding_length(self,input_length):
+
+    def get_padding_length(self, input_length):
 
         input_length = math.ceil(input_length * self.hparams.resample)
-
-        for layer in range(self.hparams.encoder_decoder["depth"]):  # encoder operation
-            input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
-            input_length = max(1,input_length)
-        for layer in range(self.hparams.encoder_decoder["depth"]):  # decoder operaration
-            input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
-        input_length = math.ceil(input_length/self.hparams.resample)
 
+        for layer in range(
+            self.hparams.encoder_decoder["depth"]
+        ):  # encoder operation
+            input_length = (
+                math.ceil(
+                    (input_length - self.hparams.encoder_decoder["kernel_size"])
+                    / self.hparams.encoder_decoder["stride"]
+                )
+                + 1
+            )
+            input_length = max(1, input_length)
+        for layer in range(
+            self.hparams.encoder_decoder["depth"]
+        ):  # decoder operation
+            input_length = (input_length - 1) * self.hparams.encoder_decoder[
+                "stride"
+            ] + self.hparams.encoder_decoder["kernel_size"]
+        input_length = math.ceil(input_length / self.hparams.resample)
 
         return int(input_length)
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py
index f799352..4d5cc0a 100644
--- a/enhancer/models/waveunet.py
+++ b/enhancer/models/waveunet.py
@@ -7,76 +7,117 @@ from typing import Optional, Union, List
 from enhancer.models.model import Model
 from enhancer.data.dataset import EnhancerDataset
 
-class WavenetDecoder(nn.Module):
+class WavenetDecoder(nn.Module):
     def __init__(
         self,
-        in_channels:int,
-        out_channels:int,
-        kernel_size:int=5,
-        padding:int=2,
-        stride:int=1,
-        dilation:int=1,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 5,
+        padding: int = 2,
+        stride: int = 1,
+        dilation: int = 1,
     ):
-        super(WavenetDecoder,self).__init__()
+        super(WavenetDecoder, self).__init__()
         self.decoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            ),
             nn.BatchNorm1d(out_channels),
-            nn.LeakyReLU(negative_slope=0.1)
+            nn.LeakyReLU(negative_slope=0.1),
         )
-
-    def forward(self,waveform):
-
+
+    def forward(self, waveform):
+
         return self.decoder(waveform)
 
 
-class WavenetEncoder(nn.Module):
+class WavenetEncoder(nn.Module):
     def __init__(
         self,
-        in_channels:int,
-        out_channels:int,
-        kernel_size:int=15,
-        padding:int=7,
-        stride:int=1,
-        dilation:int=1,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 15,
+        padding: int = 7,
+        stride: int = 1,
+        dilation: int = 1,
     ):
-        super(WavenetEncoder,self).__init__()
+        super(WavenetEncoder, self).__init__()
         self.encoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            ),
             nn.BatchNorm1d(out_channels),
-            nn.LeakyReLU(negative_slope=0.1)
+            nn.LeakyReLU(negative_slope=0.1),
        )
-
-    def forward(
-        self,
-        waveform
-    ):
 
+    def forward(self, waveform):
         return self.encoder(waveform)
 
 
 class WaveUnet(Model):
+    """
+    Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
+
+    parameters:
+        num_channels : int, defaults to 1
+            number of channels in input audio
+        depth : int, defaults to 12
+            depth of the network
+        initial_output_channels : int, defaults to 24
+            number of output channels in the first encoder layer
+        sampling_rate : int, defaults to 16000 (16 kHz)
+            sampling rate of input audio
+        lr : float, defaults to 1e-3
+            learning rate used for training
+        dataset : EnhancerDataset, optional
+            EnhancerDataset object containing train/validation data for training
+        duration : float, optional
+            chunk duration in seconds
+        loss : string or List of strings
+            loss function to be used, available ("mse","mae","SI-SDR")
+        metric : string or List of strings
+            metric function to be used, available ("mse","mae","SI-SDR")
+    """
 
     def __init__(
         self,
-        num_channels:int=1,
-        depth:int=12,
-        initial_output_channels:int=24,
-        sampling_rate:int=16000,
-        lr:float=1e-3,
-        dataset:Optional[EnhancerDataset]=None,
-        duration:Optional[float]=None,
+        num_channels: int = 1,
+        depth: int = 12,
+        initial_output_channels: int = 24,
+        sampling_rate: int = 16000,
+        lr: float = 1e-3,
+        dataset: Optional[EnhancerDataset] = None,
+        duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric:Union[str,List] = "mse"
+        metric: Union[str, List] = "mse",
     ):
-        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        duration = (
+            dataset.duration if isinstance(dataset, EnhancerDataset) else duration
+        )
         if dataset is not None:
-            if sampling_rate!=dataset.sampling_rate:
-                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            if sampling_rate != dataset.sampling_rate:
+                logging.warn(
+                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+                )
                 sampling_rate = dataset.sampling_rate
-        super().__init__(num_channels=num_channels,
-                        sampling_rate=sampling_rate,lr=lr,
-                        dataset=dataset,duration=duration,loss=loss, metric=metric
+        super().__init__(
+            num_channels=num_channels,
+            sampling_rate=sampling_rate,
+            lr=lr,
+            dataset=dataset,
+            duration=duration,
+            loss=loss,
+            metric=metric,
         )
         self.save_hyperparameters("depth")
         self.encoders = nn.ModuleList()
@@ -84,72 +125,76 @@ class WaveUnet(Model):
         out_channels = initial_output_channels
 
         for layer in range(depth):
-            encoder = WavenetEncoder(num_channels,out_channels)
+            encoder = WavenetEncoder(num_channels, out_channels)
             self.encoders.append(encoder)
 
             num_channels = out_channels
             out_channels += initial_output_channels
-            if layer == depth -1 :
-                decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels)
+            if layer == depth - 1:
+                decoder = WavenetDecoder(
+                    depth * initial_output_channels + num_channels, num_channels
+                )
             else:
-                decoder = WavenetDecoder(num_channels+out_channels,num_channels)
+                decoder = WavenetDecoder(
+                    num_channels + out_channels, num_channels
+                )
 
-            self.decoders.insert(0,decoder)
+            self.decoders.insert(0, decoder)
 
         bottleneck_dim = depth * initial_output_channels
         self.bottleneck = nn.Sequential(
-            nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1,
-                    padding=7),
+            nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
             nn.BatchNorm1d(bottleneck_dim),
-            nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            nn.LeakyReLU(negative_slope=0.1, inplace=True),
         )
 
         self.final = nn.Sequential(
             nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
-            nn.Tanh()
+            nn.Tanh(),
         )
 
-    def forward(
-        self,waveform
-    ):
+    def forward(self, waveform):
 
         if waveform.dim() == 2:
             waveform = waveform.unsqueeze(1)
-        if waveform.size(1)!=1:
-            raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels")
+        if waveform.size(1) != 1:
+            raise TypeError(
+                f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
+            )
 
         encoder_outputs = []
         out = waveform
         for encoder in self.encoders:
             out = encoder(out)
-            encoder_outputs.insert(0,out)
-            out = out[:,:,::2]
-
+            encoder_outputs.insert(0, out)
+            out = out[:, :, ::2]
+
         out = self.bottleneck(out)
 
-        for layer,decoder in enumerate(self.decoders):
+        for layer, decoder in enumerate(self.decoders):
             out = F.interpolate(out, scale_factor=2, mode="linear")
-            out = self.fix_last_dim(out,encoder_outputs[layer])
-            out = torch.cat([out,encoder_outputs[layer]],dim=1)
+            out = self.fix_last_dim(out, encoder_outputs[layer])
+            out = torch.cat([out, encoder_outputs[layer]], dim=1)
             out = decoder(out)
 
-        out = torch.cat([out, waveform],dim=1)
+        out = torch.cat([out, waveform], dim=1)
         out = self.final(out)
 
         return out
-
-    def fix_last_dim(self,x,target):
+
+    def fix_last_dim(self, x, target):
         """
-        trying to do centre crop along last dimension
+        centre crop along last dimension
         """
 
-        assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension"
+        assert (
+            x.shape[-1] >= target.shape[-1]
+        ), "input dimension cannot be smaller than target dimension"
         if x.shape[-1] == target.shape[-1]:
             return x
-
+
         diff = x.shape[-1] - target.shape[-1]
-        if diff%2!=0:
-            x = F.pad(x,(0,1))
+        if diff % 2 != 0:
+            x = F.pad(x, (0, 1))
             diff += 1
-        crop = diff//2
-        return x[:,:,crop:-crop]
+        crop = diff // 2
+        return x[:, :, crop:-crop]
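
Review note (not part of the patch): a shape-only smoke test is a quick way to exercise the reformatted `Demucs` forward path. The sketch below assumes the `enhancer` package is installed and importable as laid out above, and that the base `Model` class can be constructed without a dataset, as its `Optional[EnhancerDataset]` annotation suggests. With the default `stride` of 1, `get_padding_length()` works out to the identity, so no padding is actually added before the resample/encode/decode round trip.

```python
# Hypothetical smoke test for the refactored Demucs; not part of this
# patch, and it assumes `enhancer` is importable (e.g. `pip install -e .`).
import torch

from enhancer.models.demucs import Demucs

model = Demucs()  # default encoder/decoder, LSTM, and resample=4 settings
model.eval()

# One second of 16 kHz mono audio: (batch, channels, samples).
waveform = torch.randn(2, 1, 16000)

with torch.no_grad():
    enhanced = model(waveform)

# forward() trims skip connections and resamples back down, so the output
# can be a few samples shorter than the input; only mono output is
# guaranteed here.
print(enhanced.shape)
```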
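
The same style of check applies to `WaveUnet`. Each of the 12 encoder layers halves the time axis with `out[:, :, ::2]`, and on the way back up `fix_last_dim()` centre-crops every upsampled tensor to its skip connection, so a 16000-sample input should come back at exactly its original length. A sketch under the same importability assumptions:

```python
# Hypothetical smoke test for the refactored WaveUnet; not part of this
# patch, same assumptions as the Demucs example above.
import torch

from enhancer.models.waveunet import WaveUnet

model = WaveUnet()  # defaults: depth=12, initial_output_channels=24
model.eval()

waveform = torch.randn(2, 1, 16000)  # (batch, channels, samples)

with torch.no_grad():
    enhanced = model(waveform)

# Tanh-bounded mono output, centre-cropped back to the input length.
assert enhanced.shape == waveform.shape
```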