refactor arguments
This commit is contained in:
parent 4edc90deb0
commit 79311444d8
@@ -1,8 +1,12 @@
+from typing import Optional
 from torch import nn
 import torch.nn.functional as F
 import math

+from enhancer.models.model import Model
+from enhancer.data.dataset import EnhancerDataset
 from enhancer.utils.io import Audio as audio
+from enhancer.utils.utils import merge_dict

 class DeLSTM(nn.Module):
     def __init__(
@@ -25,64 +29,68 @@ class DeLSTM(nn.Module):
         return output,(h,c)

-class Demucs(nn.Module):
+class Demucs(Model):

+    ED_DEFAULTS = {
+        "initial_output_channels":48,
+        "kernel_size":8,
+        "stride":1,
+        "depth":5,
+        "glu":True,
+        "growth_factor":2,
+    }
+    LSTM_DEFAULTS = {
+        "bidirectional":True,
+        "num_layers":2,
+    }
+
     def __init__(
         self,
-        c_in:int=1,
-        c_out:int=1,
-        hidden:int=48,
-        kernel_size:int=8,
-        stride:int=4,
-        growth_factor:int=2,
-        depth:int = 5,
-        glu:bool = True,
-        bidirectional:bool=True,
+        encoder_decoder:Optional[dict]=None,
+        lstm:Optional[dict]=None,
+        num_channels:int=1,
         resample:int=4,
-        sampling_rate = 16000
+        sampling_rate = 16000,
+        dataset:Optional[EnhancerDataset]=None,

     ):
-        super().__init__()
-        self.c_in = c_in
-        self.c_out = c_out
-        self.hidden = hidden
-        self.growth_factor = growth_factor
-        self.stride = stride
-        self.kernel_size = kernel_size
-        self.depth = depth
-        self.bidirectional = bidirectional
-        self.activation = nn.GLU(1) if glu else nn.ReLU()
-        self.resample = resample
-        self.sampling_rate = sampling_rate
-        multi_factor = 2 if glu else 1
+        super().__init__(num_channels=num_channels,
+                         sampling_rate=sampling_rate,dataset=dataset)
+
+        encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
+        lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
+        self.save_hyperparameters("encoder_decoder","lstm","resample")
+
+        hidden = encoder_decoder["initial_output_channels"]
+        activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU()
+        multi_factor = 2 if encoder_decoder["glu"] else 1

         self.encoder = nn.ModuleList()
         self.decoder = nn.ModuleList()

-        for layer in range(self.depth):
+        for layer in range(encoder_decoder["depth"]):

-            encoder_layer = [nn.Conv1d(c_in,hidden,kernel_size,stride),
+            encoder_layer = [nn.Conv1d(num_channels,hidden,encoder_decoder["kernel_size"],encoder_decoder["stride"]),
                     nn.ReLU(),
-                    nn.Conv1d(hidden, hidden*multi_factor,kernel_size,1),
-                    self.activation]
+                    nn.Conv1d(hidden, hidden*multi_factor,encoder_decoder["kernel_size"],1),
+                    activation]
             encoder_layer = nn.Sequential(*encoder_layer)
             self.encoder.append(encoder_layer)

-            decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,kernel_size,1),
-                    self.activation,
-                    nn.ConvTranspose1d(hidden,c_out,kernel_size,stride)
+            decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,encoder_decoder["kernel_size"],1),
+                    activation,
+                    nn.ConvTranspose1d(hidden,num_channels,encoder_decoder["kernel_size"],encoder_decoder["stride"])
                     ]
             if layer>0:
                 decoder_layer.append(nn.ReLU())
             decoder_layer = nn.Sequential(*decoder_layer)
             self.decoder.insert(0,decoder_layer)

-            c_out = hidden
-            c_in = hidden
-            hidden = self.growth_factor * hidden
+            num_channels = hidden
+            hidden = encoder_decoder["growth_factor"] * hidden

-        self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
+        self.de_lstm = DeLSTM(input_size=num_channels,hidden_size=num_channels,num_layers=lstm["num_layers"],bidirectional=lstm["bidirectional"])

     def forward(self,mixed_signal):
@@ -91,14 +99,13 @@ class Demucs(nn.Module):

         length = mixed_signal.shape[-1]
         x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
-        if self.resample>1:
-            x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate,
-                                target_sr=int(self.sampling_rate * self.resample))
+        if self.hparams.resample>1:
+            x = audio.pt_resample_audio(audio=x, sr=self.hparams.sampling_rate,
+                                target_sr=int(self.hparams.sampling_rate * self.hparams.resample))

         encoder_outputs = []
         for encoder in self.encoder:
             x = encoder(x)
-            print(x.shape)
             encoder_outputs.append(x)
         x = x.permute(0,2,1)
         x,_ = self.de_lstm(x)
@@ -109,23 +116,23 @@ class Demucs(nn.Module):
            x += skip_connection[..., :x.shape[-1]]
            x = decoder(x)

-        if self.resample > 1:
-            x = audio.pt_resample_audio(x,int(self.sampling_rate * self.resample),
-                                self.sampling_rate)
+        if self.hparams.resample > 1:
+            x = audio.pt_resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
+                                self.hparams.sampling_rate)

         return x

     def get_padding_length(self,input_length):

-        input_length = math.ceil(input_length * self.resample)
+        input_length = math.ceil(input_length * self.hparams.resample)

-        for layer in range(self.depth): # encoder operation
-            input_length = math.ceil((input_length - self.kernel_size)/self.stride)+1
+        for layer in range(self.hparams.encoder_decoder["depth"]): # encoder operation
+            input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
             input_length = max(1,input_length)
-        for layer in range(self.depth): # decoder operaration
-            input_length = (input_length-1) * self.stride + self.kernel_size
-        input_length = math.ceil(input_length/self.resample)
+        for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operation
+            input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
+        input_length = math.ceil(input_length/self.hparams.resample)

         return int(input_length)
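For reference, a minimal usage sketch of the refactored constructor. The argument names and defaults are taken from this diff; the module path and the override values shown are assumptions, and the dataset/training setup is omitted.

# Hypothetical usage sketch; the import path is assumed, only the signature comes from the diff above.
from enhancer.models.demucs import Demucs

model = Demucs(
    encoder_decoder={"depth": 4, "glu": False},  # partial override; remaining keys filled from ED_DEFAULTS via merge_dict
    lstm={"num_layers": 1},                      # partial override; remaining keys filled from LSTM_DEFAULTS
    num_channels=1,
    resample=4,
    sampling_rate=16000,
)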