refactor encoder-decoder
This commit is contained in:
parent
8a90899663
commit
3f40b54fc6
|
|
@ -1,3 +1,5 @@
|
||||||
|
from base64 import encode
|
||||||
|
from turtle import forward
|
||||||
from typing import Optional, Union, List
|
from typing import Optional, Union, List
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -8,7 +10,7 @@ from enhancer.data.dataset import EnhancerDataset
|
||||||
from enhancer.utils.io import Audio as audio
|
from enhancer.utils.io import Audio as audio
|
||||||
from enhancer.utils.utils import merge_dict
|
from enhancer.utils.utils import merge_dict
|
||||||
|
|
||||||
class DeLSTM(nn.Module):
|
class DemucsLSTM(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_size:int,
|
input_size:int,
|
||||||
|
|
@ -29,6 +31,59 @@ class DeLSTM(nn.Module):
|
||||||
|
|
||||||
return output,(h,c)
|
return output,(h,c)
|
||||||
|
|
||||||
|
|
||||||
|
class DemucsEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_channels:int,
|
||||||
|
hidden_size:int,
|
||||||
|
kernel_size:int,
|
||||||
|
stride:int=1,
|
||||||
|
glu:bool=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
activation = nn.GLU(1) if glu else nn.ReLU()
|
||||||
|
multi_factor = 2 if glu else 1
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Conv1d(num_channels,hidden_size,kernel_size,stride),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1),
|
||||||
|
activation
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self,waveform):
|
||||||
|
|
||||||
|
return self.encoder(waveform)
|
||||||
|
|
||||||
|
class DemucsDecoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_channels:int,
|
||||||
|
hidden_size:int,
|
||||||
|
kernel_size:int,
|
||||||
|
stride:int=1,
|
||||||
|
glu:bool=False,
|
||||||
|
layer:int=0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
activation = nn.GLU(1) if glu else nn.ReLU()
|
||||||
|
multi_factor = 2 if glu else 1
|
||||||
|
self.decoder = nn.Sequential(
|
||||||
|
nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1),
|
||||||
|
activation,
|
||||||
|
nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
|
||||||
|
)
|
||||||
|
if layer>0:
|
||||||
|
self.decoder.add_module("4", nn.ReLU())
|
||||||
|
|
||||||
|
def forward(self,waveform,):
|
||||||
|
|
||||||
|
out = self.decoder(waveform)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Demucs(Model):
|
class Demucs(Model):
|
||||||
|
|
||||||
ED_DEFAULTS = {
|
ED_DEFAULTS = {
|
||||||
|
|
@ -56,44 +111,45 @@ class Demucs(Model):
|
||||||
loss:Union[str, List] = "mse"
|
loss:Union[str, List] = "mse"
|
||||||
|
|
||||||
):
|
):
|
||||||
|
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
||||||
super().__init__(num_channels=num_channels,
|
super().__init__(num_channels=num_channels,
|
||||||
sampling_rate=sampling_rate,lr=lr,
|
sampling_rate=sampling_rate,lr=lr,
|
||||||
dataset=dataset,loss=loss)
|
dataset=dataset,duration=duration,loss=loss)
|
||||||
|
|
||||||
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
|
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
|
||||||
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
|
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
|
||||||
self.save_hyperparameters("encoder_decoder","lstm","resample")
|
self.save_hyperparameters("encoder_decoder","lstm","resample")
|
||||||
|
|
||||||
hidden = encoder_decoder["initial_output_channels"]
|
hidden = encoder_decoder["initial_output_channels"]
|
||||||
activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU()
|
|
||||||
multi_factor = 2 if encoder_decoder["glu"] else 1
|
|
||||||
|
|
||||||
self.encoder = nn.ModuleList()
|
self.encoder = nn.ModuleList()
|
||||||
self.decoder = nn.ModuleList()
|
self.decoder = nn.ModuleList()
|
||||||
|
|
||||||
for layer in range(encoder_decoder["depth"]):
|
for layer in range(encoder_decoder["depth"]):
|
||||||
|
|
||||||
encoder_layer = [nn.Conv1d(num_channels,hidden,encoder_decoder["kernel_size"],encoder_decoder["stride"]),
|
encoder_layer = DemucsEncoder(num_channels=num_channels,
|
||||||
nn.ReLU(),
|
hidden_size=hidden,
|
||||||
nn.Conv1d(hidden, hidden*multi_factor,encoder_decoder["kernel_size"],1),
|
kernel_size=encoder_decoder["kernel_size"],
|
||||||
activation]
|
stride=encoder_decoder["stride"],
|
||||||
encoder_layer = nn.Sequential(*encoder_layer)
|
glu=encoder_decoder["glu"],
|
||||||
|
)
|
||||||
self.encoder.append(encoder_layer)
|
self.encoder.append(encoder_layer)
|
||||||
|
|
||||||
decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,encoder_decoder["kernel_size"],1),
|
decoder_layer = DemucsDecoder(num_channels=num_channels,
|
||||||
activation,
|
hidden_size=hidden,
|
||||||
nn.ConvTranspose1d(hidden,num_channels,encoder_decoder["kernel_size"],encoder_decoder["stride"])
|
kernel_size=encoder_decoder["kernel_size"],
|
||||||
]
|
stride=1,
|
||||||
if layer>0:
|
glu=encoder_decoder["glu"],
|
||||||
decoder_layer.append(nn.ReLU())
|
layer=layer
|
||||||
decoder_layer = nn.Sequential(*decoder_layer)
|
)
|
||||||
self.decoder.insert(0,decoder_layer)
|
self.decoder.insert(0,decoder_layer)
|
||||||
|
|
||||||
num_channels = hidden
|
num_channels = hidden
|
||||||
hidden = self.ED_DEFAULTS["growth_factor"] * hidden
|
hidden = self.ED_DEFAULTS["growth_factor"] * hidden
|
||||||
|
|
||||||
|
|
||||||
self.de_lstm = DeLSTM(input_size=num_channels,hidden_size=num_channels,num_layers=lstm["num_layers"],bidirectional=lstm["bidirectional"])
|
self.de_lstm = DemucsLSTM(input_size=num_channels,
|
||||||
|
hidden_size=num_channels,
|
||||||
|
num_layers=lstm["num_layers"],
|
||||||
|
bidirectional=lstm["bidirectional"]
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self,mixed_signal):
|
def forward(self,mixed_signal):
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue