# mayavoz/enhancer/models/demucs.py
import logging
import math
from typing import List, Optional, Union

import torch.nn.functional as F
from torch import nn

from enhancer.data.dataset import EnhancerDataset
from enhancer.models.model import Model
from enhancer.utils.io import Audio as audio
from enhancer.utils.utils import merge_dict


class DemucsLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        super().__init__()
        # batch_first=True because Demucs.forward feeds (batch, time, features);
        # without it the LSTM would read the batch dimension as time.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            bidirectional=bidirectional,
            batch_first=True,
        )
        # A bidirectional LSTM doubles the feature dimension; project it back.
        dim = 2 if bidirectional else 1
        self.linear = nn.Linear(dim * hidden_size, hidden_size)

    def forward(self, x):
        output, (h, c) = self.lstm(x)
        output = self.linear(output)
        return output, (h, c)
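
# A minimal shape check for DemucsLSTM (illustrative values, not from the source):
#   lstm = DemucsLSTM(input_size=768, hidden_size=768, num_layers=2)
#   out, _ = lstm(torch.randn(2, 100, 768))  # (batch, time, features) -> (2, 100, 768)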


class DemucsEncoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
    ):
        super().__init__()
        # GLU halves its input along dim 1, so the 1x1 conv doubles the
        # channels beforehand; with plain ReLU no widening is needed.
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.encoder = nn.Sequential(
            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
            nn.ReLU(),
            nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
            activation,
        )

    def forward(self, waveform):
        return self.encoder(waveform)
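
# Each encoder layer shortens the signal as L_out = floor((L_in - kernel_size) / stride) + 1;
# with the defaults (kernel_size=8, stride=4) a 1024-sample input becomes 255 frames.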


class DemucsDecoder(nn.Module):
    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
        layer: int = 0,
    ):
        super().__init__()
        activation = nn.GLU(1) if glu else nn.ReLU()
        multi_factor = 2 if glu else 1
        self.decoder = nn.Sequential(
            nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
            activation,
            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
        )
        # Intermediate decoder layers end in a ReLU; the outermost layer
        # (layer == 0) stays linear so the model can output raw waveforms.
        if layer > 0:
            self.decoder.add_module("4", nn.ReLU())

    def forward(
        self,
        waveform,
    ):
        out = self.decoder(waveform)
        return out
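
# ConvTranspose1d inverts the encoder's striding: L_out = (L_in - 1) * stride + kernel_size,
# e.g. 255 frames -> (255 - 1) * 4 + 8 = 1024 samples with the default kernel_size=8, stride=4.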


class Demucs(Model):
    """
    Demucs model from https://arxiv.org/pdf/1911.13254.pdf

    parameters:
        encoder_decoder : dict, optional
            keyword arguments passed to the encoder/decoder blocks
        lstm : dict, optional
            keyword arguments passed to the LSTM block
        num_channels : int, defaults to 1
            number of channels in the input audio
        resample : int, defaults to 4
            upsampling factor applied before the encoder and undone afterwards
        sampling_rate : int, defaults to 16000 (16 kHz)
            sampling rate of the input audio
        normalize : bool, defaults to True
            whether to standardize the input by its standard deviation
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset : EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for training
        duration : float, optional
            chunk duration in seconds
        loss : string or list of strings
            loss function(s) to be used, available: "mse", "mae", "SI-SDR"
        metric : string or list of strings
            metric function(s) to be used, available: "mse", "mae", "SI-SDR"
        floor : float, defaults to 1e-3
            stabilizer added to the standard deviation during normalization
    """

    ED_DEFAULTS = {
        "initial_output_channels": 48,
        "kernel_size": 8,
        "stride": 4,
        "depth": 5,
        "glu": True,
        "growth_factor": 2,
    }
    LSTM_DEFAULTS = {
        "bidirectional": True,
        "num_layers": 2,
    }
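
    # User-supplied dicts are assumed to be merged over these defaults (via
    # merge_dict), so a partial override such as Demucs(encoder_decoder={"depth": 4})
    # would keep the remaining default values intact.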

    def __init__(
        self,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        num_channels: int = 1,
        resample: int = 4,
        sampling_rate: int = 16000,
        normalize: bool = True,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
        floor: float = 1e-3,
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )
        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
        hidden = encoder_decoder["initial_output_channels"]
        self.normalize = normalize
        self.floor = floor
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for layer in range(encoder_decoder["depth"]):
            encoder_layer = DemucsEncoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
                glu=encoder_decoder["glu"],
            )
            self.encoder.append(encoder_layer)
            # Decoders are built innermost-last but applied outermost-last, so
            # prepend: self.decoder[i] mirrors self.encoder[depth - 1 - i].
            decoder_layer = DemucsDecoder(
                num_channels=num_channels,
                hidden_size=hidden,
                kernel_size=encoder_decoder["kernel_size"],
                stride=encoder_decoder["stride"],
                glu=encoder_decoder["glu"],
                layer=layer,
            )
            self.decoder.insert(0, decoder_layer)
            num_channels = hidden
            # Use the merged config, not the class default, so user-supplied
            # growth_factor values take effect.
            hidden = encoder_decoder["growth_factor"] * hidden
        self.de_lstm = DemucsLSTM(
            input_size=num_channels,
            hidden_size=num_channels,
            num_layers=lstm["num_layers"],
            bidirectional=lstm["bidirectional"],
        )
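
    # With the defaults (48 initial channels, growth_factor 2, depth 5) the
    # encoder widths are 48, 96, 192, 384, 768, so the LSTM runs on 768 features.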

    def forward(self, waveform):
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)
        if waveform.size(1) != self.hparams.num_channels:
            raise ValueError(
                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
            )
        if self.normalize:
            # Standardize by the per-sample deviation; the floor keeps the
            # division stable for near-silent inputs.
            waveform = waveform.mean(dim=1, keepdim=True)
            std = waveform.std(dim=-1, keepdim=True)
            waveform = waveform / (self.floor + std)
        else:
            std = 1
        # Pad so the encoder/decoder strides divide the signal cleanly.
        length = waveform.shape[-1]
        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
        # Upsample before encoding; the inverse resampling happens after decoding.
        if self.hparams.resample > 1:
            x = audio.resample_audio(
                audio=x,
                sr=self.hparams.sampling_rate,
                target_sr=int(
                    self.hparams.sampling_rate * self.hparams.resample
                ),
            )
        # Encoder: keep each layer's output for the skip connections.
        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)
        # LSTM bottleneck expects (batch, time, features).
        x = x.permute(0, 2, 1)
        x, _ = self.de_lstm(x)
        x = x.permute(0, 2, 1)
        # Decoder: add the matching encoder output (trimmed to the current
        # length) before each upsampling layer.
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            x = x + skip_connection[..., : x.shape[-1]]
            x = decoder(x)
        if self.hparams.resample > 1:
            x = audio.resample_audio(
                x,
                int(self.hparams.sampling_rate * self.hparams.resample),
                self.hparams.sampling_rate,
            )
        # Trim the padding and undo the normalization.
        out = x[..., :length]
        return std * out

    def get_padding_length(self, input_length):
        # Simulate the length transformations of the whole pipeline to find the
        # smallest valid padded length, then map it back to the original rate.
        input_length = math.ceil(input_length * self.hparams.resample)
        for layer in range(
            self.hparams.encoder_decoder["depth"]
        ):  # encoder operation
            input_length = (
                math.ceil(
                    (input_length - self.hparams.encoder_decoder["kernel_size"])
                    / self.hparams.encoder_decoder["stride"]
                )
                + 1
            )
            input_length = max(1, input_length)
        for layer in range(
            self.hparams.encoder_decoder["depth"]
        ):  # decoder operation
            input_length = (input_length - 1) * self.hparams.encoder_decoder[
                "stride"
            ] + self.hparams.encoder_decoder["kernel_size"]
        input_length = math.ceil(input_length / self.hparams.resample)
        return int(input_length)
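

if __name__ == "__main__":
    # A minimal smoke test; values are illustrative and not from the source.
    # Assumes the model can run a forward pass without a dataset attached.
    import torch

    model = Demucs(sampling_rate=16000, num_channels=1)
    model.eval()
    noisy = torch.randn(2, 1, 16000)  # (batch, channels, samples): 1 s at 16 kHz
    with torch.no_grad():
        enhanced = model(noisy)
    print(enhanced.shape)  # expected: torch.Size([2, 1, 16000])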