Wave-U-Net encoder/decoder

shahules786 2022-09-23 18:21:15 +05:30
parent 0fa16054d9
commit 24c7a6f1f0
1 changed file with 51 additions and 12 deletions


@@ -1,4 +1,4 @@
-from turtle import forward
+import torch
 import torch.nn as nn
@@ -8,14 +7,14 @@ class WavenetDecoder(nn.Module):
         self,
         in_channels:int,
         out_channels:int,
-        kernel_size:int,
-        padding:int,
-        stride:int,
+        kernel_size:int=5,
+        padding:int=2,
+        stride:int=1,
         dilation:int=1,
     ):
         super(WavenetDecoder,self).__init__()
         self.decoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
+            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
             nn.BatchNorm1d(out_channels),
             nn.LeakyReLU(negative_slope=0.1)
         )
@@ -30,13 +29,14 @@ class WavenetEncoder(nn.Module):
         self,
         in_channels:int,
         out_channels:int,
-        kernel_size:int,
-        padding:int,
-        stride:int,
+        kernel_size:int=15,
+        padding:int=7,
+        stride:int=1,
         dilation:int=1,
     ):
+        super(WavenetEncoder,self).__init__()
         self.encoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
+            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
             nn.BatchNorm1d(out_channels),
             nn.LeakyReLU(negative_slope=0.1)
         )
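Aside, not part of the commit: with the defaults introduced here, both blocks are length-preserving, because padding = (kernel_size - 1) / 2 with stride 1 and dilation 1. A minimal shape check, assuming the file is importable as `waveunet` (the diff does not show the file path) and that the unchanged parts of the file define each block's forward pass:

import torch
from waveunet import WavenetDecoder, WavenetEncoder  # module name is an assumption

enc = WavenetEncoder(in_channels=1, out_channels=24)  # kernel 15, padding 7 -> same length
dec = WavenetDecoder(in_channels=24, out_channels=1)  # kernel 5,  padding 2 -> same length
x = torch.randn(2, 1, 4000)                           # (batch, channels, samples)
print(enc(x).shape)                                   # torch.Size([2, 24, 4000])
print(dec(enc(x)).shape)                              # torch.Size([2, 1, 4000])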
@@ -54,11 +54,50 @@
 class WaveUnet(nn.Module):
     def __init__(
-        self
+        self,
+        inp_channels:int=1,
+        num_layers:int=12,
+        initial_output_channels:int=24
     ):
-        pass
+        super(WaveUnet,self).__init__()
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        out_channels = initial_output_channels
+        for layer in range(num_layers):
+            encoder = WavenetEncoder(inp_channels,out_channels)
+            self.encoders.append(encoder)
+            inp_channels = out_channels
+            out_channels += initial_output_channels
+            if layer == num_layers - 1:
+                decoder = WavenetDecoder(num_layers * initial_output_channels + inp_channels,inp_channels)
+            else:
+                decoder = WavenetDecoder(inp_channels + out_channels,inp_channels)
+            self.decoders.insert(0,decoder)
+        bottleneck_dim = num_layers * initial_output_channels
+        self.bottleneck = nn.Sequential(
+            nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1, padding=7),
+            nn.BatchNorm1d(bottleneck_dim),
+            nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        )
     def forward(
         self,waveform
     ):
-        pass
+        skips = []                      # encoder outputs reused as skip connections
+        out = waveform
+        for encoder in self.encoders:
+            out = encoder(out)          # chain each encoder on the previous output
+            skips.append(out)
+        out = self.bottleneck(out)
+        # decoders run from the deepest level up; concatenate the matching
+        # encoder output along the channel axis before each block
+        for decoder, skip in zip(self.decoders, reversed(skips)):
+            out = decoder(torch.cat([out, skip], dim=1))
+        return out
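The channel bookkeeping above means encoder i produces (i + 1) * initial_output_channels channels, the bottleneck stays at num_layers * initial_output_channels, and each decoder consumes its predecessor's output concatenated with the matching encoder output. A quick end-to-end smoke test under the same assumptions as the sketch above (importable module, default arguments); every convolution preserves the time axis, so a mono input comes back with initial_output_channels channels and the original number of samples:

import torch
from waveunet import WaveUnet  # module name is an assumption

model = WaveUnet(inp_channels=1, num_layers=12, initial_output_channels=24)
waveform = torch.randn(1, 1, 16000)   # e.g. one second of 16 kHz mono audio
out = model(waveform)
print(out.shape)                      # torch.Size([1, 24, 16000])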