refactor arguments

shahules786 2022-09-07 13:10:30 +05:30
parent 4edc90deb0
commit 79311444d8
1 changed file with 52 additions and 45 deletions

@@ -1,8 +1,12 @@
+from typing import Optional
 from torch import nn
 import torch.nn.functional as F
 import math
+from enhancer.models.model import Model
+from enhancer.data.dataset import EnhancerDataset
 from enhancer.utils.io import Audio as audio
+from enhancer.utils.utils import merge_dict


 class DeLSTM(nn.Module):
     def __init__(
@@ -25,64 +29,68 @@ class DeLSTM(nn.Module):
         return output,(h,c)


-class Demucs(nn.Module):
+class Demucs(Model):
+
+    ED_DEFAULTS = {
+        "initial_output_channels":48,
+        "kernel_size":8,
+        "stride":1,
+        "depth":5,
+        "glu":True,
+        "growth_factor":2,
+    }
+    LSTM_DEFAULTS = {
+        "bidirectional":True,
+        "num_layers":2,
+    }
 
     def __init__(
         self,
-        c_in:int=1,
-        c_out:int=1,
-        hidden:int=48,
-        kernel_size:int=8,
-        stride:int=4,
-        growth_factor:int=2,
-        depth:int = 5,
-        glu:bool = True,
-        bidirectional:bool=True,
+        encoder_decoder:Optional[dict]=None,
+        lstm:Optional[dict]=None,
+        num_channels:int=1,
         resample:int=4,
-        sampling_rate = 16000
+        sampling_rate = 16000,
+        dataset:Optional[EnhancerDataset]=None,
     ):
-        super().__init__()
-        self.c_in = c_in
-        self.c_out = c_out
-        self.hidden = hidden
-        self.growth_factor = growth_factor
-        self.stride = stride
-        self.kernel_size = kernel_size
-        self.depth = depth
-        self.bidirectional = bidirectional
-        self.activation = nn.GLU(1) if glu else nn.ReLU()
-        self.resample = resample
-        self.sampling_rate = sampling_rate
-        multi_factor = 2 if glu else 1
+        super().__init__(num_channels=num_channels,
+                        sampling_rate=sampling_rate,dataset=dataset)
+
+        encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
+        lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
+        self.save_hyperparameters("encoder_decoder","lstm","resample")
+
+        hidden = encoder_decoder["initial_output_channels"]
+        activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU()
+        multi_factor = 2 if encoder_decoder["glu"] else 1
 
         self.encoder = nn.ModuleList()
         self.decoder = nn.ModuleList()
-        for layer in range(self.depth):
-            encoder_layer = [nn.Conv1d(c_in,hidden,kernel_size,stride),
+        for layer in range(encoder_decoder["depth"]):
+            encoder_layer = [nn.Conv1d(num_channels,hidden,encoder_decoder["kernel_size"],encoder_decoder["stride"]),
                             nn.ReLU(),
-                            nn.Conv1d(hidden, hidden*multi_factor,kernel_size,1),
-                            self.activation]
+                            nn.Conv1d(hidden, hidden*multi_factor,encoder_decoder["kernel_size"],1),
+                            activation]
             encoder_layer = nn.Sequential(*encoder_layer)
             self.encoder.append(encoder_layer)
 
-            decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,kernel_size,1),
-                            self.activation,
-                            nn.ConvTranspose1d(hidden,c_out,kernel_size,stride)
+            decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,encoder_decoder["kernel_size"],1),
+                            activation,
+                            nn.ConvTranspose1d(hidden,num_channels,encoder_decoder["kernel_size"],encoder_decoder["stride"])
                             ]
             if layer>0:
                 decoder_layer.append(nn.ReLU())
             decoder_layer = nn.Sequential(*decoder_layer)
             self.decoder.insert(0,decoder_layer)
 
-            c_out = hidden
-            c_in = hidden
-            hidden = self.growth_factor * hidden
+            num_channels = hidden
+            hidden = encoder_decoder["growth_factor"] * hidden
 
-        self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
+        self.de_lstm = DeLSTM(input_size=num_channels,hidden_size=num_channels,num_layers=lstm["num_layers"],bidirectional=lstm["bidirectional"])
 
     def forward(self,mixed_signal):
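
Note: merge_dict is imported from enhancer.utils.utils and its implementation is not part of this diff. The sketch below assumes the usual behavior of such a helper: user-supplied keys are overlaid on the class-level defaults, and None leaves the defaults untouched.

    from typing import Optional

    # Assumed stand-in for enhancer.utils.utils.merge_dict; the real
    # implementation is not shown in this diff.
    def merge_dict(defaults: dict, overrides: Optional[dict]) -> dict:
        merged = dict(defaults)          # start from ED_DEFAULTS / LSTM_DEFAULTS
        merged.update(overrides or {})   # caller-supplied keys win; None keeps defaults
        return merged

Under that assumption, Demucs(encoder_decoder={"depth": 4}) overrides only the depth and inherits every other ED_DEFAULTS entry.
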
@@ -91,14 +99,13 @@ class Demucs(nn.Module):
         length = mixed_signal.shape[-1]
         x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
-        if self.resample>1:
-            x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate,
-                            target_sr=int(self.sampling_rate * self.resample))
+        if self.hparams.resample>1:
+            x = audio.pt_resample_audio(audio=x, sr=self.hparams.sampling_rate,
+                            target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
 
         encoder_outputs = []
         for encoder in self.encoder:
             x = encoder(x)
-            print(x.shape)
             encoder_outputs.append(x)
 
         x = x.permute(0,2,1)
         x,_ = self.de_lstm(x)
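
Note: audio.pt_resample_audio comes from enhancer.utils.io and is likewise outside this diff; a rough equivalent of the resample-before-encoding step, sketched with torchaudio rather than the project's actual helper:

    import torch
    import torchaudio.functional as AF

    # Assumed equivalent of audio.pt_resample_audio: plain sinc resampling.
    def pt_resample_audio(audio: torch.Tensor, sr: int, target_sr: int) -> torch.Tensor:
        return AF.resample(audio, orig_freq=sr, new_freq=target_sr)

    x = torch.randn(1, 1, 16000)             # (batch, channel, time): 1 s at 16 kHz
    x = pt_resample_audio(x, 16000, 64000)   # resample=4 -> encode at 64 kHz
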
@@ -109,23 +116,23 @@ class Demucs(nn.Module):
             x += skip_connection[..., :x.shape[-1]]
             x = decoder(x)
 
-        if self.resample > 1:
-            x = audio.pt_resample_audio(x,int(self.sampling_rate * self.resample),
-                            self.sampling_rate)
+        if self.hparams.resample > 1:
+            x = audio.pt_resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
+                            self.hparams.sampling_rate)
 
         return x
 
     def get_padding_length(self,input_length):
-        input_length = math.ceil(input_length * self.resample)
-        for layer in range(self.depth):    # encoder operation
-            input_length = math.ceil((input_length - self.kernel_size)/self.stride)+1
+        input_length = math.ceil(input_length * self.hparams.resample)
+        for layer in range(self.hparams.encoder_decoder["depth"]):    # encoder operation
+            input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
             input_length = max(1,input_length)
-        for layer in range(self.depth):    # decoder operation
-            input_length = (input_length-1) * self.stride + self.kernel_size
-        input_length = math.ceil(input_length/self.resample)
+        for layer in range(self.hparams.encoder_decoder["depth"]):    # decoder operation
+            input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
+        input_length = math.ceil(input_length/self.hparams.resample)
 
         return int(input_length)
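
Note: get_padding_length runs the conv arithmetic forward and backward: each encoder layer maps a length L to ceil((L - kernel_size)/stride) + 1, each decoder layer inverts it as (L - 1)*stride + kernel_size, and the whole computation is bracketed by the resample factor. A standalone replica for a quick sanity check; the example value assumes the pre-refactor stride of 4 (with the new default stride of 1 the function returns the input length unchanged):

    import math

    # Standalone replica of Demucs.get_padding_length for checking padded lengths.
    def padding_length(length, kernel_size=8, stride=4, depth=5, resample=4):
        length = math.ceil(length * resample)                 # upsample first
        for _ in range(depth):                                # encoder convs
            length = max(1, math.ceil((length - kernel_size) / stride) + 1)
        for _ in range(depth):                                # decoder transposed convs
            length = (length - 1) * stride + kernel_size
        return int(math.ceil(length / resample))              # back to the input rate

    print(padding_length(16000))   # -> 16213: a 16000-sample input is padded to 16213
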