diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index 2a9d083..bb20c3f 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -1,4 +1,3 @@ -from turtle import forward import torch.nn as nn @@ -8,14 +7,14 @@ class WavenetDecoder(nn.Module): self, in_channels:int, out_channels:int, - kernel_size:int, - padding:int, - stride:int, + kernel_size:int=5, + padding:int=2, + stride:int=1, dilation:int=1, ): super(WavenetDecoder,self).__init__() self.decoder = nn.Sequential( - nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding), + nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), nn.BatchNorm1d(out_channels), nn.LeakyReLU(negative_slope=0.1) ) @@ -30,13 +29,14 @@ class WavenetEncoder(nn.Module): self, in_channels:int, out_channels:int, - kernel_size:int, - padding:int, - stride:int, + kernel_size:int=15, + padding:int=7, + stride:int=1, dilation:int=1, ): + super(WavenetEncoder,self).__init__() self.encoder = nn.Sequential( - nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding), + nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation), nn.BatchNorm1d(out_channels), nn.LeakyReLU(negative_slope=0.1) ) @@ -54,11 +54,50 @@ class WavenetEncoder(nn.Module): class WaveUnet(nn.Module): def __init__( - self + self, + inp_channels:int=1, + num_layers:int=12, + initial_output_channels:int=24 ): - pass + super(WaveUnet,self).__init__() + + self.encoders = nn.ModuleList() + self.decoders = nn.ModuleList() + out_channels = initial_output_channels + for layer in range(num_layers): + + encoder = WavenetEncoder(inp_channels,out_channels) + self.encoders.append(encoder) + + inp_channels = out_channels + out_channels += initial_output_channels + if layer == num_layers -1 : + decoder = WavenetDecoder(num_layers * initial_output_channels + inp_channels,inp_channels) + else: + decoder = WavenetDecoder(inp_channels+out_channels,inp_channels) + + self.decoders.insert(0,decoder) + + bottleneck_dim = num_layers * initial_output_channels + self.bottleneck = nn.Sequential( + nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1, + padding=7), + nn.BatchNorm1d(bottleneck_dim), + nn.LeakyReLU(negative_slope=0.1, inplace=True) + ) + + def forward( self,waveform ): - pass \ No newline at end of file + + for encoder in self.encoders: + out = encoder(waveform) + + out = self.bottleneck(out) + + for decoder in self.decoders: + out = decoder(out) + + return decoder \ No newline at end of file