mayavoz/enhancer/models/dccrn.py

241 lines
7.2 KiB
Python

import logging
from typing import Any, List, Optional, Tuple, Union
from torch import nn
from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import ComplexConv2d, ComplexLSTM
from enhancer.models.complexnn.conv import ComplexConvTranspose2d
from enhancer.models.complexnn.utils import ComplexBatchNorm2D, ComplexRelu
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict
class DCCRN_ENCODER(nn.Module):
    """Single encoder stage: complex 2-D convolution, normalisation, activation.

    The three layers are chained in an ``nn.Sequential`` stored as
    ``self.encoder``.

    Args:
        in_channels: input channel count handed to ``ComplexConv2d``.
        out_channel: output channel count handed to ``ComplexConv2d`` and
            used as the feature count of the normalisation layer.
        kernel_size: kernel size forwarded to ``ComplexConv2d``.
        complex_norm: use ``ComplexBatchNorm2D`` when true, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: use ``ComplexRelu`` when true, otherwise ``nn.PReLU``.
        stride: stride forwarded to ``ComplexConv2d``.
        padding: padding forwarded to ``ComplexConv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()

        # The activation is instantiated first and the convolution second so
        # that module construction order stays the same as before.
        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        conv = ComplexConv2d(
            in_channels,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        if complex_norm:
            norm = ComplexBatchNorm2D(out_channel)
        else:
            norm = nn.BatchNorm2d(out_channel)

        self.encoder = nn.Sequential(conv, norm, nonlinearity)

    def forward(self, waveform):
        """Run ``waveform`` through conv -> norm -> activation."""
        encoded = self.encoder(waveform)
        return encoded
class DCCRN_DECODER(nn.Module):
    """Single decoder stage: complex transposed 2-D convolution, normalisation, activation.

    The three layers are chained in an ``nn.Sequential`` stored as
    ``self.decoder``.

    Args:
        in_channels: input channel count handed to ``ComplexConvTranspose2d``.
        out_channels: output channel count handed to
            ``ComplexConvTranspose2d`` and used as the feature count of the
            normalisation layer.
        kernel_size: kernel size forwarded to ``ComplexConvTranspose2d``.
        complex_norm: use ``ComplexBatchNorm2D`` when true, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: use ``ComplexRelu`` when true, otherwise ``nn.PReLU``.
        stride: stride forwarded to ``ComplexConvTranspose2d``.
        padding: padding forwarded to ``ComplexConvTranspose2d``.
        output_padding: output padding forwarded to
            ``ComplexConvTranspose2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()

        # The activation is instantiated first and the convolution second so
        # that module construction order stays the same as before.
        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        deconv = ComplexConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )

        if complex_norm:
            norm = ComplexBatchNorm2D(out_channels)
        else:
            norm = nn.BatchNorm2d(out_channels)

        self.decoder = nn.Sequential(deconv, norm, nonlinearity)

    def forward(self, waveform):
        """Run ``waveform`` through transposed conv -> norm -> activation."""
        decoded = self.decoder(waveform)
        return decoded
class DCCRN(Model):
    """DCCRN-style enhancement model: STFT front end, encoder/decoder stacks and an LSTM.

    The constructor builds the following submodules:

    * ``self.stft`` / ``self.istft`` -- ``ConvSTFT`` / ``ConviSTFT`` built from
      the merged STFT settings.
    * ``self.encoder`` / ``self.decoder`` -- ``nn.ModuleList`` stacks of
      ``DCCRN_ENCODER`` / ``DCCRN_DECODER`` stages. The decoder list is kept in
      reverse order of the encoder list.
    * ``self.lstm`` -- an ``nn.Sequential`` of ``ComplexLSTM`` layers when
      ``complex_lstm`` is true, otherwise a single ``nn.LSTM``.

    NOTE(review): ``forward`` currently returns its input unchanged, so none
    of these submodules are applied yet.

    Args:
        stft: overrides for ``STFT_DEFAULTS``.
        encoder_decoder: overrides for ``ED_DEFAULTS``.
        lstm: overrides for ``LSTM_DEFAULTS``.
        complex_lstm: build ``ComplexLSTM`` layers instead of ``nn.LSTM``.
        complex_norm: forwarded to every encoder/decoder stage.
        complex_relu: forwarded to every encoder/decoder stage.
        masking_mode: stored on the instance; not otherwise used by the
            visible code.
        num_channels: forwarded to ``Model``; doubled to obtain the input
            channel count of the first encoder stage.
        sampling_rate: forwarded to ``Model``. Replaced by
            ``dataset.sampling_rate`` (with a warning) when the two differ.
        lr: forwarded to ``Model``.
        dataset: optional dataset, forwarded to ``Model``.
        duration: forwarded to ``Model``. Taken from ``dataset.duration``
            when ``dataset`` is an ``EnhancerDataset``.
        loss: forwarded to ``Model``.
        metric: forwarded to ``Model``.
    """

    # Positional arguments for ConvSTFT / ConviSTFT.
    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }
    # Settings for the encoder/decoder stacks.
    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }
    # Settings for the recurrent part.
    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        # A dataset dictates the duration; without one, keep the value the
        # caller passed instead of silently resetting it to None.
        if isinstance(dataset, EnhancerDataset):
            duration = dataset.duration
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        # The local names must stay `encoder_decoder`, `lstm` and `stft`:
        # save_hyperparameters() is given these names as strings, so the
        # merged dictionaries have to be bound to them before the call.
        # (merge_dict presumably overlays the overrides on the defaults --
        # confirm against enhancer.utils.utils.)
        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            # complex_relu changes the architecture just like complex_norm,
            # so it is saved alongside the other switches.
            "complex_relu",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        depth = encoder_decoder["depth"]
        # Honour the configured growth factor (default 2) rather than a
        # hard-coded constant.
        growth_factor = encoder_decoder["growth_factor"]
        conv_kernel = (encoder_decoder["kernel_size"], 2)
        conv_stride = (encoder_decoder["stride"], 1)

        # The first stage sees twice the model's channel count.
        in_channels = num_channels * 2
        out_channels = encoder_decoder["initial_output_channels"]
        for layer in range(depth):
            self.encoder.append(
                DCCRN_ENCODER(
                    in_channels,
                    out_channels,
                    kernel_size=conv_kernel,
                    stride=conv_stride,
                    padding=(encoder_decoder["padding"], 1),
                    complex_norm=complex_norm,
                    complex_relu=complex_relu,
                )
            )
            # Each decoder stage mirrors its encoder stage and takes twice
            # the encoder's output channels as input. Inserting at index 0
            # keeps the decoder list in reverse encoder order.
            self.decoder.insert(
                0,
                DCCRN_DECODER(
                    out_channels + out_channels,
                    in_channels,
                    kernel_size=conv_kernel,
                    stride=conv_stride,
                    padding=(encoder_decoder["padding"], 0),
                    output_padding=(encoder_decoder["output_padding"], 0),
                    complex_norm=complex_norm,
                    complex_relu=complex_relu,
                ),
            )
            in_channels = out_channels
            # The channel width grows on every stage except the last three.
            if layer < depth - 3:
                out_channels *= growth_factor

        # Feature size seen by the recurrent part:
        # (nfft / 2**depth) * (final channel width / 2), truncated to int.
        # Computed once so both LSTM variants get a valid integer size.
        rnn_feature_size = int((stft["nfft"] / 2 ** depth) * (out_channels / 2))

        if self.complex_lstm:
            last_layer = lstm["num_layers"] - 1
            lstms = []
            for layer in range(lstm["num_layers"]):
                # The first layer consumes the encoder features and the last
                # layer projects back to that size; layers in between work at
                # the LSTM hidden size.
                input_size = (
                    rnn_feature_size if layer == 0 else lstm["hidden_size"]
                )
                projection_size = (
                    rnn_feature_size if layer == last_layer else None
                )
                lstms.append(
                    ComplexLSTM(
                        input_size=input_size,
                        hidden_size=lstm["hidden_size"],
                        num_layers=1,
                        projection_size=projection_size,
                    )
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            self.lstm = nn.LSTM(
                input_size=rnn_feature_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )

    def forward(self, waveform):
        """Return ``waveform`` unchanged.

        NOTE(review): this looks like a placeholder -- none of the submodules
        built in ``__init__`` are applied here.
        """
        return waveform