wave-u-net init
This commit is contained in:
parent
3f40b54fc6
commit
0fa16054d9
|
|
@ -0,0 +1,64 @@
|
|||
from turtle import forward
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class WavenetDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels:int,
|
||||
out_channels:int,
|
||||
kernel_size:int,
|
||||
padding:int,
|
||||
stride:int,
|
||||
dilation:int=1,
|
||||
):
|
||||
super(WavenetDecoder,self).__init__()
|
||||
self.decoder = nn.Sequential(
|
||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
|
||||
nn.BatchNorm1d(out_channels),
|
||||
nn.LeakyReLU(negative_slope=0.1)
|
||||
)
|
||||
|
||||
def forward(self,waveform):
|
||||
|
||||
return self.decoder(waveform)
|
||||
|
||||
class WavenetEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels:int,
|
||||
out_channels:int,
|
||||
kernel_size:int,
|
||||
padding:int,
|
||||
stride:int,
|
||||
dilation:int=1,
|
||||
):
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding),
|
||||
nn.BatchNorm1d(out_channels),
|
||||
nn.LeakyReLU(negative_slope=0.1)
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
waveform
|
||||
):
|
||||
return self.encoder(waveform)
|
||||
|
||||
|
||||
|
||||
|
||||
class WaveUnet(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,waveform
|
||||
):
|
||||
pass
|
||||
Loading…
Reference in New Issue