refactor encoder-decoder

This commit is contained in:
shahules786 2022-09-21 10:36:56 +05:30
parent 8a90899663
commit 3f40b54fc6
1 changed files with 76 additions and 20 deletions

View File

@ -1,3 +1,5 @@
from base64 import encode
from turtle import forward
from typing import Optional, Union, List from typing import Optional, Union, List
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
@ -8,7 +10,7 @@ from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.io import Audio as audio from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict from enhancer.utils.utils import merge_dict
class DeLSTM(nn.Module): class DemucsLSTM(nn.Module):
def __init__( def __init__(
self, self,
input_size:int, input_size:int,
@ -29,6 +31,59 @@ class DeLSTM(nn.Module):
return output,(h,c) return output,(h,c)
class DemucsEncoder(nn.Module):
def __init__(
self,
num_channels:int,
hidden_size:int,
kernel_size:int,
stride:int=1,
glu:bool=False,
):
super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1
self.encoder = nn.Sequential(
nn.Conv1d(num_channels,hidden_size,kernel_size,stride),
nn.ReLU(),
nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1),
activation
)
def forward(self,waveform):
return self.encoder(waveform)
class DemucsDecoder(nn.Module):
def __init__(
self,
num_channels:int,
hidden_size:int,
kernel_size:int,
stride:int=1,
glu:bool=False,
layer:int=0
):
super().__init__()
activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1
self.decoder = nn.Sequential(
nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1),
activation,
nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
)
if layer>0:
self.decoder.add_module("4", nn.ReLU())
def forward(self,waveform,):
out = self.decoder(waveform)
return out
class Demucs(Model): class Demucs(Model):
ED_DEFAULTS = { ED_DEFAULTS = {
@ -56,44 +111,45 @@ class Demucs(Model):
loss:Union[str, List] = "mse" loss:Union[str, List] = "mse"
): ):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
super().__init__(num_channels=num_channels, super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr, sampling_rate=sampling_rate,lr=lr,
dataset=dataset,loss=loss) dataset=dataset,duration=duration,loss=loss)
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder) encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS,lstm) lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
self.save_hyperparameters("encoder_decoder","lstm","resample") self.save_hyperparameters("encoder_decoder","lstm","resample")
hidden = encoder_decoder["initial_output_channels"] hidden = encoder_decoder["initial_output_channels"]
activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU()
multi_factor = 2 if encoder_decoder["glu"] else 1
self.encoder = nn.ModuleList() self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList() self.decoder = nn.ModuleList()
for layer in range(encoder_decoder["depth"]): for layer in range(encoder_decoder["depth"]):
encoder_layer = [nn.Conv1d(num_channels,hidden,encoder_decoder["kernel_size"],encoder_decoder["stride"]), encoder_layer = DemucsEncoder(num_channels=num_channels,
nn.ReLU(), hidden_size=hidden,
nn.Conv1d(hidden, hidden*multi_factor,encoder_decoder["kernel_size"],1), kernel_size=encoder_decoder["kernel_size"],
activation] stride=encoder_decoder["stride"],
encoder_layer = nn.Sequential(*encoder_layer) glu=encoder_decoder["glu"],
)
self.encoder.append(encoder_layer) self.encoder.append(encoder_layer)
decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,encoder_decoder["kernel_size"],1), decoder_layer = DemucsDecoder(num_channels=num_channels,
activation, hidden_size=hidden,
nn.ConvTranspose1d(hidden,num_channels,encoder_decoder["kernel_size"],encoder_decoder["stride"]) kernel_size=encoder_decoder["kernel_size"],
] stride=1,
if layer>0: glu=encoder_decoder["glu"],
decoder_layer.append(nn.ReLU()) layer=layer
decoder_layer = nn.Sequential(*decoder_layer) )
self.decoder.insert(0,decoder_layer) self.decoder.insert(0,decoder_layer)
num_channels = hidden num_channels = hidden
hidden = self.ED_DEFAULTS["growth_factor"] * hidden hidden = self.ED_DEFAULTS["growth_factor"] * hidden
self.de_lstm = DeLSTM(input_size=num_channels,hidden_size=num_channels,num_layers=lstm["num_layers"],bidirectional=lstm["bidirectional"]) self.de_lstm = DemucsLSTM(input_size=num_channels,
hidden_size=num_channels,
num_layers=lstm["num_layers"],
bidirectional=lstm["bidirectional"]
)
def forward(self,mixed_signal): def forward(self,mixed_signal):