diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py new file mode 100644 index 0000000..2a9d083 --- /dev/null +++ b/enhancer/models/waveunet.py @@ -0,0 +1,64 @@ +from turtle import forward +import torch.nn as nn + + +class WavenetDecoder(nn.Module): + + def __init__( + self, + in_channels:int, + out_channels:int, + kernel_size:int, + padding:int, + stride:int, + dilation:int=1, + ): + super(WavenetDecoder,self).__init__() + self.decoder = nn.Sequential( + nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding), + nn.BatchNorm1d(out_channels), + nn.LeakyReLU(negative_slope=0.1) + ) + + def forward(self,waveform): + + return self.decoder(waveform) + +class WavenetEncoder(nn.Module): + + def __init__( + self, + in_channels:int, + out_channels:int, + kernel_size:int, + padding:int, + stride:int, + dilation:int=1, + ): + self.encoder = nn.Sequential( + nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding), + nn.BatchNorm1d(out_channels), + nn.LeakyReLU(negative_slope=0.1) + ) + + + def forward( + self, + waveform + ): + return self.encoder(waveform) + + + + +class WaveUnet(nn.Module): + + def __init__( + self + ): + pass + + def forward( + self,waveform + ): + pass \ No newline at end of file