diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index bb20c3f..4a0e20f 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -1,5 +1,12 @@ +from tkinter import wantobjects +import wave +import torch import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Union, List +from enhancer.models.model import Model +from enhancer.data.dataset import EnhancerDataset class WavenetDecoder(nn.Module): @@ -49,32 +56,39 @@ class WavenetEncoder(nn.Module): return self.encoder(waveform) - - -class WaveUnet(nn.Module): +class WaveUnet(Model): def __init__( self, - inp_channels:int=1, + num_channels:int=1, num_layers:int=12, - initial_output_channels:int=24 + initial_output_channels:int=24, + sampling_rate:int=16000, + lr:float=1e-3, + dataset:Optional[EnhancerDataset]=None, + duration:Optional[float]=None, + loss: Union[str, List] = "mse", + metric:Union[str,List] = "mse" ): - super(WaveUnet,self).__init__() - + super().__init__(num_channels=num_channels, + sampling_rate=sampling_rate,lr=lr, + dataset=dataset,duration=duration,loss=loss, metric=metric + ) + self.save_hyperparameters("num_layers") self.encoders = nn.ModuleList() self.decoders = nn.ModuleList() out_channels = initial_output_channels for layer in range(num_layers): - encoder = WavenetEncoder(inp_channels,out_channels) + encoder = WavenetEncoder(num_channels,out_channels) self.encoders.append(encoder) - inp_channels = out_channels + num_channels = out_channels out_channels += initial_output_channels if layer == num_layers -1 : - decoder = WavenetDecoder(num_layers * initial_output_channels + inp_channels,inp_channels) + decoder = WavenetDecoder(num_layers * initial_output_channels + num_channels,inp_channels) else: - decoder = WavenetDecoder(inp_channels+out_channels,inp_channels) + decoder = WavenetDecoder(num_channels+out_channels,num_channels) self.decoders.insert(0,decoder) @@ -85,19 +99,54 @@ class WaveUnet(nn.Module): nn.BatchNorm1d(bottleneck_dim), nn.LeakyReLU(negative_slope=0.1, inplace=True) ) - + self.final = nn.Sequential( + nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1), + nn.Tanh() + ) def forward( self,waveform ): - - for encoder in self.encoders: - out = encoder(waveform) + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + if waveform.size(1)!=1: + raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels") + + encoder_outputs = [] + out = waveform + for encoder in self.encoders: + out = encoder(out) + encoder_outputs.insert(0,out) + out = out[:,:,::2] + out = self.bottleneck(out) - for decoder in self.decoders: + for layer,decoder in enumerate(self.decoders): + out = F.interpolate(out, scale_factor=2, mode="linear") + print(out.shape,encoder_outputs[layer].shape) + out = self.fix_last_dim(out,encoder_outputs[layer]) + out = torch.cat([out,encoder_outputs[layer]],dim=1) out = decoder(out) - return decoder \ No newline at end of file + out = torch.cat([out, waveform],dim=1) + out = self.final(out) + return out + + def fix_last_dim(self,x,target): + """ + trying to do centre crop along last dimension + """ + + assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension" + if x.shape[-1] == target.shape[-1]: + return x + + diff = x.shape[-1] - target.shape[-1] + if diff%2!=0: + x = F.pad(x,(0,1)) + diff += 1 + + crop = diff//2 + return x[:,:,crop:-crop]