refactor models
This commit is contained in:
parent
2cf9803ed1
commit
451058c29d
|
|
@ -9,209 +9,255 @@ from enhancer.data.dataset import EnhancerDataset
|
||||||
from enhancer.utils.io import Audio as audio
|
from enhancer.utils.io import Audio as audio
|
||||||
from enhancer.utils.utils import merge_dict
|
from enhancer.utils.utils import merge_dict
|
||||||
|
|
||||||
|
|
||||||
class DemucsLSTM(nn.Module):
|
class DemucsLSTM(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_size:int,
|
input_size: int,
|
||||||
hidden_size:int,
|
hidden_size: int,
|
||||||
num_layers:int,
|
num_layers: int,
|
||||||
bidirectional:bool=True
|
bidirectional: bool = True,
|
||||||
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
|
self.lstm = nn.LSTM(
|
||||||
|
input_size, hidden_size, num_layers, bidirectional=bidirectional
|
||||||
|
)
|
||||||
dim = 2 if bidirectional else 1
|
dim = 2 if bidirectional else 1
|
||||||
self.linear = nn.Linear(dim*hidden_size,hidden_size)
|
self.linear = nn.Linear(dim * hidden_size, hidden_size)
|
||||||
|
|
||||||
def forward(self,x):
|
def forward(self, x):
|
||||||
|
|
||||||
output,(h,c) = self.lstm(x)
|
output, (h, c) = self.lstm(x)
|
||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
|
|
||||||
return output,(h,c)
|
return output, (h, c)
|
||||||
|
|
||||||
|
|
||||||
class DemucsEncoder(nn.Module):
|
class DemucsEncoder(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_channels:int,
|
num_channels: int,
|
||||||
hidden_size:int,
|
hidden_size: int,
|
||||||
kernel_size:int,
|
kernel_size: int,
|
||||||
stride:int=1,
|
stride: int = 1,
|
||||||
glu:bool=False,
|
glu: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
activation = nn.GLU(1) if glu else nn.ReLU()
|
activation = nn.GLU(1) if glu else nn.ReLU()
|
||||||
multi_factor = 2 if glu else 1
|
multi_factor = 2 if glu else 1
|
||||||
self.encoder = nn.Sequential(
|
self.encoder = nn.Sequential(
|
||||||
nn.Conv1d(num_channels,hidden_size,kernel_size,stride),
|
nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv1d(hidden_size, hidden_size*multi_factor,kernel_size,1),
|
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
|
||||||
activation
|
activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self,waveform):
|
def forward(self, waveform):
|
||||||
|
|
||||||
return self.encoder(waveform)
|
return self.encoder(waveform)
|
||||||
|
|
||||||
class DemucsDecoder(nn.Module):
|
|
||||||
|
|
||||||
|
class DemucsDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_channels:int,
|
num_channels: int,
|
||||||
hidden_size:int,
|
hidden_size: int,
|
||||||
kernel_size:int,
|
kernel_size: int,
|
||||||
stride:int=1,
|
stride: int = 1,
|
||||||
glu:bool=False,
|
glu: bool = False,
|
||||||
layer:int=0
|
layer: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
activation = nn.GLU(1) if glu else nn.ReLU()
|
activation = nn.GLU(1) if glu else nn.ReLU()
|
||||||
multi_factor = 2 if glu else 1
|
multi_factor = 2 if glu else 1
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
nn.Conv1d(hidden_size,hidden_size*multi_factor,kernel_size,1),
|
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
|
||||||
activation,
|
activation,
|
||||||
nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
|
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
|
||||||
)
|
)
|
||||||
if layer>0:
|
if layer > 0:
|
||||||
self.decoder.add_module("4", nn.ReLU())
|
self.decoder.add_module("4", nn.ReLU())
|
||||||
|
|
||||||
def forward(self,waveform,):
|
def forward(
|
||||||
|
self,
|
||||||
|
waveform,
|
||||||
|
):
|
||||||
|
|
||||||
out = self.decoder(waveform)
|
out = self.decoder(waveform)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Demucs(Model):
|
class Demucs(Model):
|
||||||
|
"""
|
||||||
|
Demucs model from https://arxiv.org/pdf/1911.13254.pdf
|
||||||
|
parameters:
|
||||||
|
encoder_decoder: dict, optional
|
||||||
|
keyword arguments passed to encoder decoder block
|
||||||
|
lstm : dict, optional
|
||||||
|
keyword arguments passed to LSTM block
|
||||||
|
num_channels: int, defaults to 1
|
||||||
|
number channels in input audio
|
||||||
|
sampling_rate: int, defaults to 16KHz
|
||||||
|
sampling rate of input audio
|
||||||
|
lr : float, defaults to 1e-3
|
||||||
|
learning rate used for training
|
||||||
|
dataset: EnhancerDataset, optional
|
||||||
|
EnhancerDataset object containing train/validation data for training
|
||||||
|
duration : float, optional
|
||||||
|
chunk duration in seconds
|
||||||
|
loss : string or List of strings
|
||||||
|
loss function to be used, available ("mse","mae","SI-SDR")
|
||||||
|
metric : string or List of strings
|
||||||
|
metric function to be used, available ("mse","mae","SI-SDR")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
ED_DEFAULTS = {
|
ED_DEFAULTS = {
|
||||||
"initial_output_channels":48,
|
"initial_output_channels": 48,
|
||||||
"kernel_size":8,
|
"kernel_size": 8,
|
||||||
"stride":1,
|
"stride": 1,
|
||||||
"depth":5,
|
"depth": 5,
|
||||||
"glu":True,
|
"glu": True,
|
||||||
"growth_factor":2,
|
"growth_factor": 2,
|
||||||
}
|
}
|
||||||
LSTM_DEFAULTS = {
|
LSTM_DEFAULTS = {
|
||||||
"bidirectional":True,
|
"bidirectional": True,
|
||||||
"num_layers":2,
|
"num_layers": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
encoder_decoder:Optional[dict]=None,
|
encoder_decoder: Optional[dict] = None,
|
||||||
lstm:Optional[dict]=None,
|
lstm: Optional[dict] = None,
|
||||||
num_channels:int=1,
|
num_channels: int = 1,
|
||||||
resample:int=4,
|
resample: int = 4,
|
||||||
sampling_rate = 16000,
|
sampling_rate=16000,
|
||||||
lr:float=1e-3,
|
lr: float = 1e-3,
|
||||||
dataset:Optional[EnhancerDataset]=None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
loss:Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric:Union[str, List] = "mse"
|
metric: Union[str, List] = "mse",
|
||||||
|
|
||||||
|
|
||||||
):
|
):
|
||||||
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
duration = (
|
||||||
|
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||||
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate!=dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
|
logging.warn(
|
||||||
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
super().__init__(num_channels=num_channels,
|
super().__init__(
|
||||||
sampling_rate=sampling_rate,lr=lr,
|
num_channels=num_channels,
|
||||||
dataset=dataset,duration=duration,loss=loss, metric=metric)
|
sampling_rate=sampling_rate,
|
||||||
|
lr=lr,
|
||||||
|
dataset=dataset,
|
||||||
|
duration=duration,
|
||||||
|
loss=loss,
|
||||||
|
metric=metric,
|
||||||
|
)
|
||||||
|
|
||||||
encoder_decoder = merge_dict(self.ED_DEFAULTS,encoder_decoder)
|
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
|
||||||
lstm = merge_dict(self.LSTM_DEFAULTS,lstm)
|
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
|
||||||
self.save_hyperparameters("encoder_decoder","lstm","resample")
|
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
|
||||||
hidden = encoder_decoder["initial_output_channels"]
|
hidden = encoder_decoder["initial_output_channels"]
|
||||||
self.encoder = nn.ModuleList()
|
self.encoder = nn.ModuleList()
|
||||||
self.decoder = nn.ModuleList()
|
self.decoder = nn.ModuleList()
|
||||||
|
|
||||||
for layer in range(encoder_decoder["depth"]):
|
for layer in range(encoder_decoder["depth"]):
|
||||||
|
|
||||||
encoder_layer = DemucsEncoder(num_channels=num_channels,
|
encoder_layer = DemucsEncoder(
|
||||||
hidden_size=hidden,
|
num_channels=num_channels,
|
||||||
kernel_size=encoder_decoder["kernel_size"],
|
hidden_size=hidden,
|
||||||
stride=encoder_decoder["stride"],
|
kernel_size=encoder_decoder["kernel_size"],
|
||||||
glu=encoder_decoder["glu"],
|
stride=encoder_decoder["stride"],
|
||||||
)
|
glu=encoder_decoder["glu"],
|
||||||
|
)
|
||||||
self.encoder.append(encoder_layer)
|
self.encoder.append(encoder_layer)
|
||||||
|
|
||||||
decoder_layer = DemucsDecoder(num_channels=num_channels,
|
decoder_layer = DemucsDecoder(
|
||||||
hidden_size=hidden,
|
num_channels=num_channels,
|
||||||
kernel_size=encoder_decoder["kernel_size"],
|
hidden_size=hidden,
|
||||||
stride=1,
|
kernel_size=encoder_decoder["kernel_size"],
|
||||||
glu=encoder_decoder["glu"],
|
stride=1,
|
||||||
layer=layer
|
glu=encoder_decoder["glu"],
|
||||||
)
|
layer=layer,
|
||||||
self.decoder.insert(0,decoder_layer)
|
)
|
||||||
|
self.decoder.insert(0, decoder_layer)
|
||||||
|
|
||||||
num_channels = hidden
|
num_channels = hidden
|
||||||
hidden = self.ED_DEFAULTS["growth_factor"] * hidden
|
hidden = self.ED_DEFAULTS["growth_factor"] * hidden
|
||||||
|
|
||||||
self.de_lstm = DemucsLSTM(input_size=num_channels,
|
self.de_lstm = DemucsLSTM(
|
||||||
hidden_size=num_channels,
|
input_size=num_channels,
|
||||||
num_layers=lstm["num_layers"],
|
hidden_size=num_channels,
|
||||||
bidirectional=lstm["bidirectional"]
|
num_layers=lstm["num_layers"],
|
||||||
)
|
bidirectional=lstm["bidirectional"],
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self,waveform):
|
def forward(self, waveform):
|
||||||
|
|
||||||
if waveform.dim() == 2:
|
if waveform.dim() == 2:
|
||||||
waveform = waveform.unsqueeze(1)
|
waveform = waveform.unsqueeze(1)
|
||||||
|
|
||||||
if waveform.size(1)!=1:
|
if waveform.size(1) != 1:
|
||||||
raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
|
raise TypeError(
|
||||||
|
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
||||||
|
)
|
||||||
|
|
||||||
length = waveform.shape[-1]
|
length = waveform.shape[-1]
|
||||||
x = F.pad(waveform, (0,self.get_padding_length(length) - length))
|
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
|
||||||
if self.hparams.resample>1:
|
if self.hparams.resample > 1:
|
||||||
x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
|
x = audio.resample_audio(
|
||||||
target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
|
audio=x,
|
||||||
|
sr=self.hparams.sampling_rate,
|
||||||
|
target_sr=int(
|
||||||
|
self.hparams.sampling_rate * self.hparams.resample
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = []
|
encoder_outputs = []
|
||||||
for encoder in self.encoder:
|
for encoder in self.encoder:
|
||||||
x = encoder(x)
|
x = encoder(x)
|
||||||
encoder_outputs.append(x)
|
encoder_outputs.append(x)
|
||||||
x = x.permute(0,2,1)
|
x = x.permute(0, 2, 1)
|
||||||
x,_ = self.de_lstm(x)
|
x, _ = self.de_lstm(x)
|
||||||
|
|
||||||
x = x.permute(0,2,1)
|
x = x.permute(0, 2, 1)
|
||||||
for decoder in self.decoder:
|
for decoder in self.decoder:
|
||||||
skip_connection = encoder_outputs.pop(-1)
|
skip_connection = encoder_outputs.pop(-1)
|
||||||
x += skip_connection[..., :x.shape[-1]]
|
x += skip_connection[..., : x.shape[-1]]
|
||||||
x = decoder(x)
|
x = decoder(x)
|
||||||
|
|
||||||
if self.hparams.resample > 1:
|
if self.hparams.resample > 1:
|
||||||
x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
|
x = audio.resample_audio(
|
||||||
self.hparams.sampling_rate)
|
x,
|
||||||
|
int(self.hparams.sampling_rate * self.hparams.resample),
|
||||||
|
self.hparams.sampling_rate,
|
||||||
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_padding_length(self,input_length):
|
def get_padding_length(self, input_length):
|
||||||
|
|
||||||
input_length = math.ceil(input_length * self.hparams.resample)
|
input_length = math.ceil(input_length * self.hparams.resample)
|
||||||
|
|
||||||
|
for layer in range(
|
||||||
for layer in range(self.hparams.encoder_decoder["depth"]): # encoder operation
|
self.hparams.encoder_decoder["depth"]
|
||||||
input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
|
): # encoder operation
|
||||||
input_length = max(1,input_length)
|
input_length = (
|
||||||
for layer in range(self.hparams.encoder_decoder["depth"]): # decoder operaration
|
math.ceil(
|
||||||
input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
|
(input_length - self.hparams.encoder_decoder["kernel_size"])
|
||||||
input_length = math.ceil(input_length/self.hparams.resample)
|
/ self.hparams.encoder_decoder["stride"]
|
||||||
|
)
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
input_length = max(1, input_length)
|
||||||
|
for layer in range(
|
||||||
|
self.hparams.encoder_decoder["depth"]
|
||||||
|
): # decoder operaration
|
||||||
|
input_length = (input_length - 1) * self.hparams.encoder_decoder[
|
||||||
|
"stride"
|
||||||
|
] + self.hparams.encoder_decoder["kernel_size"]
|
||||||
|
input_length = math.ceil(input_length / self.hparams.resample)
|
||||||
|
|
||||||
return int(input_length)
|
return int(input_length)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -7,76 +7,117 @@ from typing import Optional, Union, List
|
||||||
from enhancer.models.model import Model
|
from enhancer.models.model import Model
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from enhancer.data.dataset import EnhancerDataset
|
||||||
|
|
||||||
class WavenetDecoder(nn.Module):
|
|
||||||
|
|
||||||
|
class WavenetDecoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels:int,
|
in_channels: int,
|
||||||
out_channels:int,
|
out_channels: int,
|
||||||
kernel_size:int=5,
|
kernel_size: int = 5,
|
||||||
padding:int=2,
|
padding: int = 2,
|
||||||
stride:int=1,
|
stride: int = 1,
|
||||||
dilation:int=1,
|
dilation: int = 1,
|
||||||
):
|
):
|
||||||
super(WavenetDecoder,self).__init__()
|
super(WavenetDecoder, self).__init__()
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
|
nn.Conv1d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.BatchNorm1d(out_channels),
|
||||||
nn.LeakyReLU(negative_slope=0.1)
|
nn.LeakyReLU(negative_slope=0.1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self,waveform):
|
def forward(self, waveform):
|
||||||
|
|
||||||
return self.decoder(waveform)
|
return self.decoder(waveform)
|
||||||
|
|
||||||
class WavenetEncoder(nn.Module):
|
|
||||||
|
|
||||||
|
class WavenetEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels:int,
|
in_channels: int,
|
||||||
out_channels:int,
|
out_channels: int,
|
||||||
kernel_size:int=15,
|
kernel_size: int = 15,
|
||||||
padding:int=7,
|
padding: int = 7,
|
||||||
stride:int=1,
|
stride: int = 1,
|
||||||
dilation:int=1,
|
dilation: int = 1,
|
||||||
):
|
):
|
||||||
super(WavenetEncoder,self).__init__()
|
super(WavenetEncoder, self).__init__()
|
||||||
self.encoder = nn.Sequential(
|
self.encoder = nn.Sequential(
|
||||||
nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
|
nn.Conv1d(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
dilation=dilation,
|
||||||
|
),
|
||||||
nn.BatchNorm1d(out_channels),
|
nn.BatchNorm1d(out_channels),
|
||||||
nn.LeakyReLU(negative_slope=0.1)
|
nn.LeakyReLU(negative_slope=0.1),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward(self, waveform):
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
waveform
|
|
||||||
):
|
|
||||||
return self.encoder(waveform)
|
return self.encoder(waveform)
|
||||||
|
|
||||||
|
|
||||||
class WaveUnet(Model):
|
class WaveUnet(Model):
|
||||||
|
"""
|
||||||
|
Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
|
||||||
|
parameters:
|
||||||
|
num_channels: int, defaults to 1
|
||||||
|
number of channels in input audio
|
||||||
|
depth : int, defaults to 12
|
||||||
|
depth of network
|
||||||
|
initial_output_channels: int, defaults to 24
|
||||||
|
number of output channels in initial upsampling layer
|
||||||
|
sampling_rate: int, defaults to 16KHz
|
||||||
|
sampling rate of input audio
|
||||||
|
lr : float, defaults to 1e-3
|
||||||
|
learning rate used for training
|
||||||
|
dataset: EnhancerDataset, optional
|
||||||
|
EnhancerDataset object containing train/validation data for training
|
||||||
|
duration : float, optional
|
||||||
|
chunk duration in seconds
|
||||||
|
loss : string or List of strings
|
||||||
|
loss function to be used, available ("mse","mae","SI-SDR")
|
||||||
|
metric : string or List of strings
|
||||||
|
metric function to be used, available ("mse","mae","SI-SDR")
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_channels:int=1,
|
num_channels: int = 1,
|
||||||
depth:int=12,
|
depth: int = 12,
|
||||||
initial_output_channels:int=24,
|
initial_output_channels: int = 24,
|
||||||
sampling_rate:int=16000,
|
sampling_rate: int = 16000,
|
||||||
lr:float=1e-3,
|
lr: float = 1e-3,
|
||||||
dataset:Optional[EnhancerDataset]=None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
duration:Optional[float]=None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric:Union[str,List] = "mse"
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
duration = (
|
||||||
|
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||||
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate!=dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
|
logging.warn(
|
||||||
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
super().__init__(num_channels=num_channels,
|
super().__init__(
|
||||||
sampling_rate=sampling_rate,lr=lr,
|
num_channels=num_channels,
|
||||||
dataset=dataset,duration=duration,loss=loss, metric=metric
|
sampling_rate=sampling_rate,
|
||||||
|
lr=lr,
|
||||||
|
dataset=dataset,
|
||||||
|
duration=duration,
|
||||||
|
loss=loss,
|
||||||
|
metric=metric,
|
||||||
)
|
)
|
||||||
self.save_hyperparameters("depth")
|
self.save_hyperparameters("depth")
|
||||||
self.encoders = nn.ModuleList()
|
self.encoders = nn.ModuleList()
|
||||||
|
|
@ -84,72 +125,76 @@ class WaveUnet(Model):
|
||||||
out_channels = initial_output_channels
|
out_channels = initial_output_channels
|
||||||
for layer in range(depth):
|
for layer in range(depth):
|
||||||
|
|
||||||
encoder = WavenetEncoder(num_channels,out_channels)
|
encoder = WavenetEncoder(num_channels, out_channels)
|
||||||
self.encoders.append(encoder)
|
self.encoders.append(encoder)
|
||||||
|
|
||||||
num_channels = out_channels
|
num_channels = out_channels
|
||||||
out_channels += initial_output_channels
|
out_channels += initial_output_channels
|
||||||
if layer == depth -1 :
|
if layer == depth - 1:
|
||||||
decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels)
|
decoder = WavenetDecoder(
|
||||||
|
depth * initial_output_channels + num_channels, num_channels
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
decoder = WavenetDecoder(num_channels+out_channels,num_channels)
|
decoder = WavenetDecoder(
|
||||||
|
num_channels + out_channels, num_channels
|
||||||
|
)
|
||||||
|
|
||||||
self.decoders.insert(0,decoder)
|
self.decoders.insert(0, decoder)
|
||||||
|
|
||||||
bottleneck_dim = depth * initial_output_channels
|
bottleneck_dim = depth * initial_output_channels
|
||||||
self.bottleneck = nn.Sequential(
|
self.bottleneck = nn.Sequential(
|
||||||
nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1,
|
nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
|
||||||
padding=7),
|
|
||||||
nn.BatchNorm1d(bottleneck_dim),
|
nn.BatchNorm1d(bottleneck_dim),
|
||||||
nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
||||||
)
|
)
|
||||||
self.final = nn.Sequential(
|
self.final = nn.Sequential(
|
||||||
nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
|
nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
|
||||||
nn.Tanh()
|
nn.Tanh(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward(self, waveform):
|
||||||
def forward(
|
|
||||||
self,waveform
|
|
||||||
):
|
|
||||||
if waveform.dim() == 2:
|
if waveform.dim() == 2:
|
||||||
waveform = waveform.unsqueeze(1)
|
waveform = waveform.unsqueeze(1)
|
||||||
|
|
||||||
if waveform.size(1)!=1:
|
if waveform.size(1) != 1:
|
||||||
raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels")
|
raise TypeError(
|
||||||
|
f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = []
|
encoder_outputs = []
|
||||||
out = waveform
|
out = waveform
|
||||||
for encoder in self.encoders:
|
for encoder in self.encoders:
|
||||||
out = encoder(out)
|
out = encoder(out)
|
||||||
encoder_outputs.insert(0,out)
|
encoder_outputs.insert(0, out)
|
||||||
out = out[:,:,::2]
|
out = out[:, :, ::2]
|
||||||
|
|
||||||
out = self.bottleneck(out)
|
out = self.bottleneck(out)
|
||||||
|
|
||||||
for layer,decoder in enumerate(self.decoders):
|
for layer, decoder in enumerate(self.decoders):
|
||||||
out = F.interpolate(out, scale_factor=2, mode="linear")
|
out = F.interpolate(out, scale_factor=2, mode="linear")
|
||||||
out = self.fix_last_dim(out,encoder_outputs[layer])
|
out = self.fix_last_dim(out, encoder_outputs[layer])
|
||||||
out = torch.cat([out,encoder_outputs[layer]],dim=1)
|
out = torch.cat([out, encoder_outputs[layer]], dim=1)
|
||||||
out = decoder(out)
|
out = decoder(out)
|
||||||
|
|
||||||
out = torch.cat([out, waveform],dim=1)
|
out = torch.cat([out, waveform], dim=1)
|
||||||
out = self.final(out)
|
out = self.final(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def fix_last_dim(self,x,target):
|
def fix_last_dim(self, x, target):
|
||||||
"""
|
"""
|
||||||
trying to do centre crop along last dimension
|
centre crop along last dimension
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension"
|
assert (
|
||||||
|
x.shape[-1] >= target.shape[-1]
|
||||||
|
), "input dimension cannot be larger than target dimension"
|
||||||
if x.shape[-1] == target.shape[-1]:
|
if x.shape[-1] == target.shape[-1]:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
diff = x.shape[-1] - target.shape[-1]
|
diff = x.shape[-1] - target.shape[-1]
|
||||||
if diff%2!=0:
|
if diff % 2 != 0:
|
||||||
x = F.pad(x,(0,1))
|
x = F.pad(x, (0, 1))
|
||||||
diff += 1
|
diff += 1
|
||||||
|
|
||||||
crop = diff//2
|
crop = diff // 2
|
||||||
return x[:,:,crop:-crop]
|
return x[:, :, crop:-crop]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue