wave u net encoder decoder
This commit is contained in:
parent
0fa16054d9
commit
24c7a6f1f0
|
|
@ -1,4 +1,3 @@
|
|||
from turtle import forward
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
|
|
@ -8,14 +7,14 @@ class WavenetDecoder(nn.Module):
|
|||
self,
|
||||
in_channels:int,
|
||||
out_channels:int,
|
||||
kernel_size:int,
|
||||
padding:int,
|
||||
stride:int,
|
||||
kernel_size:int=5,
|
||||
padding:int=2,
|
||||
stride:int=1,
|
||||
dilation:int=1,
|
||||
):
|
||||
super(WavenetDecoder,self).__init__()
|
||||
self.decoder = nn.Sequential(
|
||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
|
||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
|
||||
nn.BatchNorm1d(out_channels),
|
||||
nn.LeakyReLU(negative_slope=0.1)
|
||||
)
|
||||
|
|
@ -30,13 +29,14 @@ class WavenetEncoder(nn.Module):
|
|||
self,
|
||||
in_channels:int,
|
||||
out_channels:int,
|
||||
kernel_size:int,
|
||||
padding:int,
|
||||
stride:int,
|
||||
kernel_size:int=15,
|
||||
padding:int=7,
|
||||
stride:int=1,
|
||||
dilation:int=1,
|
||||
):
|
||||
super(WavenetEncoder,self).__init__()
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
|
||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
|
||||
nn.BatchNorm1d(out_channels),
|
||||
nn.LeakyReLU(negative_slope=0.1)
|
||||
)
|
||||
|
|
@ -54,11 +54,50 @@ class WavenetEncoder(nn.Module):
|
|||
class WaveUnet(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self
|
||||
self,
|
||||
inp_channels:int=1,
|
||||
num_layers:int=12,
|
||||
initial_output_channels:int=24
|
||||
):
|
||||
pass
|
||||
super(WaveUnet,self).__init__()
|
||||
|
||||
self.encoders = nn.ModuleList()
|
||||
self.decoders = nn.ModuleList()
|
||||
out_channels = initial_output_channels
|
||||
for layer in range(num_layers):
|
||||
|
||||
encoder = WavenetEncoder(inp_channels,out_channels)
|
||||
self.encoders.append(encoder)
|
||||
|
||||
inp_channels = out_channels
|
||||
out_channels += initial_output_channels
|
||||
if layer == num_layers -1 :
|
||||
decoder = WavenetDecoder(num_layers * initial_output_channels + inp_channels,inp_channels)
|
||||
else:
|
||||
decoder = WavenetDecoder(inp_channels+out_channels,inp_channels)
|
||||
|
||||
self.decoders.insert(0,decoder)
|
||||
|
||||
bottleneck_dim = num_layers * initial_output_channels
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1,
|
||||
padding=7),
|
||||
nn.BatchNorm1d(bottleneck_dim),
|
||||
nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
)
|
||||
|
||||
|
||||
|
||||
def forward(
|
||||
self,waveform
|
||||
):
|
||||
pass
|
||||
|
||||
for encoder in self.encoders:
|
||||
out = encoder(waveform)
|
||||
|
||||
out = self.bottleneck(out)
|
||||
|
||||
for decoder in self.decoders:
|
||||
out = decoder(out)
|
||||
|
||||
return decoder
|
||||
Loading…
Reference in New Issue