# NOTE(review): `tkinter.wantobjects` and `wave` are not referenced anywhere in
# the visible code and look like editor auto-import leftovers; confirm nothing
# else relies on them before dropping these two imports.
from tkinter import wantobjects
import wave
from typing import Optional, Union, List

import torch
import torch.nn as nn
import torch.nn.functional as F

from enhancer.models.model import Model
from enhancer.data.dataset import EnhancerDataset


class WavenetDecoder(nn.Module):
    """Decoder-path block: ``Conv1d -> BatchNorm1d -> LeakyReLU(0.1)``.

    With the defaults (``kernel_size=5``, ``padding=2``, ``stride=1``,
    ``dilation=1``) the convolution leaves the time dimension unchanged.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Width of the 1-D convolution kernel.
        padding: Padding applied to both ends of the time axis.
        stride: Convolution stride.
        dilation: Spacing between kernel taps.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 5,
        padding: int = 2,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Apply the block to a ``(batch, in_channels, time)`` tensor."""
        return self.decoder(waveform)


class WavenetEncoder(nn.Module):
    """Encoder-path block: ``Conv1d -> BatchNorm1d -> LeakyReLU(0.1)``.

    With the defaults (``kernel_size=15``, ``padding=7``, ``stride=1``,
    ``dilation=1``) the convolution leaves the time dimension unchanged;
    downsampling is done by the caller, not by this block.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Width of the 1-D convolution kernel.
        padding: Padding applied to both ends of the time axis.
        stride: Convolution stride.
        dilation: Spacing between kernel taps.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 15,
        padding: int = 7,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            ),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Apply the block to a ``(batch, in_channels, time)`` tensor."""
        return self.encoder(waveform)


class WaveUnet(Model):
    """Wave-U-Net style encoder/decoder operating on raw mono waveforms.

    The network stacks ``num_layers`` encoder blocks whose widths grow by
    ``initial_output_channels`` per level; after every encoder the signal is
    decimated by keeping every second sample. A bottleneck convolution
    follows, then ``num_layers`` decoder blocks, each fed with the linearly
    upsampled previous output concatenated with the matching encoder output.
    The last decoder output is concatenated with the input waveform and
    projected to a single channel through a 1x1 convolution and ``tanh``.

    Args:
        num_channels: Input channels of the first encoder. ``forward`` only
            accepts mono input and the output head assumes one input channel,
            so values other than 1 will not work with this architecture.
        num_layers: Number of encoder/decoder levels (must be >= 1).
        initial_output_channels: Width of the first encoder; level ``i``
            (0-based) has ``(i + 1) * initial_output_channels`` channels.
        sampling_rate: Forwarded unchanged to ``Model.__init__``.
        lr: Forwarded unchanged to ``Model.__init__``.
        dataset: Forwarded unchanged to ``Model.__init__``.
        duration: Forwarded unchanged to ``Model.__init__``.
        loss: Forwarded unchanged to ``Model.__init__``.
        metric: Forwarded unchanged to ``Model.__init__``.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(
        self,
        num_channels: int = 1,
        num_layers: int = 12,
        initial_output_channels: int = 24,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        if num_layers < 1:
            raise ValueError(f"num_layers must be at least 1, got {num_layers}")

        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )
        self.save_hyperparameters("num_layers")

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()

        # Width of the deepest encoder output, which is also the bottleneck width.
        bottleneck_dim = num_layers * initial_output_channels

        in_channels = num_channels
        out_channels = initial_output_channels
        for layer in range(num_layers):
            self.encoders.append(WavenetEncoder(in_channels, out_channels))

            # From here on `in_channels` is the width of this level's skip
            # connection and `out_channels` the width of the next deeper level.
            in_channels = out_channels
            out_channels += initial_output_channels

            if layer == num_layers - 1:
                # Deepest level: the upsampled input comes from the bottleneck,
                # not from a deeper decoder.
                decoder = WavenetDecoder(bottleneck_dim + in_channels, in_channels)
            else:
                decoder = WavenetDecoder(in_channels + out_channels, in_channels)

            # Decoders run deepest-first, so each new (deeper) one goes in front.
            self.decoders.insert(0, decoder)

        self.bottleneck = nn.Sequential(
            nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
            nn.BatchNorm1d(bottleneck_dim),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
        )
        # Input is the last decoder output concatenated with the mono waveform.
        self.final = nn.Sequential(
            nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
            nn.Tanh(),
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of mono waveforms.

        Args:
            waveform: Tensor of shape ``(batch, time)`` or ``(batch, 1, time)``.
                A 2-D input gets a channel axis inserted at dimension 1.

        Returns:
            Tensor of shape ``(batch, 1, time)`` produced by the final
            1x1 convolution followed by ``tanh``.

        Raises:
            TypeError: If the channel dimension is not of size 1.
        """
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != 1:
            raise TypeError(
                f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
            )

        skips = []
        out = waveform
        for encoder in self.encoders:
            out = encoder(out)
            skips.append(out)
            # Decimate by two: keep every second sample along time.
            out = out[:, :, ::2]

        out = self.bottleneck(out)

        # Decoders are stored deepest-first, skips were collected shallowest-first.
        for decoder, skip in zip(self.decoders, reversed(skips)):
            out = F.interpolate(out, scale_factor=2, mode="linear")
            # Upsampling an odd-length signal overshoots by one sample.
            out = self.fix_last_dim(out, skip)
            out = torch.cat([out, skip], dim=1)
            out = decoder(out)

        out = torch.cat([out, waveform], dim=1)
        return self.final(out)

    def fix_last_dim(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Centre-crop ``x`` along its last dimension to ``target``'s length.

        When the length difference is odd, the extra sample is removed from
        the left side.

        Args:
            x: Tensor to crop; its last dimension must be at least as long as
                that of ``target``.
            target: Tensor whose last-dimension length is the desired length.

        Returns:
            ``x`` itself when the lengths already match, otherwise a slice of
            ``x`` whose last dimension equals ``target.shape[-1]``.

        Raises:
            AssertionError: If ``x`` is shorter than ``target``.
        """
        x_len = x.shape[-1]
        target_len = target.shape[-1]
        if x_len < target_len:
            # Raised explicitly so the check survives `python -O`; the exception
            # type matches the `assert` this replaces.
            raise AssertionError("input dimension cannot be smaller than target dimension")

        if x_len == target_len:
            return x

        # ceil(diff / 2) samples are dropped on the left, the rest on the right.
        start = (x_len - target_len + 1) // 2
        return x[..., start:start + target_len]