mayavoz/enhancer/models/dccrn.py

331 lines
10 KiB
Python

import logging
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import (
ComplexBatchNorm2D,
ComplexConv2d,
ComplexConvTranspose2d,
ComplexLSTM,
ComplexRelu,
)
from enhancer.models.complexnn.utils import complex_cat
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict
class DCCRN_ENCODER(nn.Module):
    """Single DCCRN encoder stage: complex convolution, normalisation, activation.

    Args:
        in_channels: channel count handed to ``ComplexConv2d`` as its input size.
        out_channel: channel count handed to ``ComplexConv2d`` as its output
            size and to the normalisation layer.
        kernel_size: 2-D kernel size of the convolution.
        complex_norm: use ``ComplexBatchNorm2D`` when True, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: use ``ComplexRelu`` when True, otherwise ``nn.PReLU``.
        stride: 2-D stride of the convolution.
        padding: 2-D padding of the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()

        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        conv = ComplexConv2d(
            in_channels,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        if complex_norm:
            norm = ComplexBatchNorm2D(out_channel)
        else:
            norm = nn.BatchNorm2d(out_channel)

        # Keep the three layers at Sequential indices 0, 1, 2 so parameter
        # names in the state dict stay the same.
        self.encoder = nn.Sequential(conv, norm, nonlinearity)

    def forward(self, waveform):
        """Run the stage on ``waveform``.

        Despite the parameter name, ``DCCRN.forward`` in this file passes
        spectrogram feature maps here, not raw audio.
        """
        return self.encoder(waveform)
class DCCRN_DECODER(nn.Module):
    """Single DCCRN decoder stage built around a complex transposed convolution.

    When ``layer`` is non-zero the convolution is followed by a normalisation
    layer and an activation; when ``layer`` is 0 the stage is the bare
    transposed convolution.

    Args:
        in_channels: channel count handed to ``ComplexConvTranspose2d`` as its
            input size.
        out_channels: channel count handed to ``ComplexConvTranspose2d`` as its
            output size and to the normalisation layer.
        kernel_size: 2-D kernel size of the transposed convolution.
        layer: index of the stage; 0 disables normalisation and activation.
        complex_norm: use ``ComplexBatchNorm2D`` when True, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: use ``ComplexRelu`` when True, otherwise ``nn.PReLU``.
        stride: 2-D stride of the transposed convolution.
        padding: 2-D padding of the transposed convolution.
        output_padding: 2-D output padding of the transposed convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        layer: int = 0,
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()

        norm_cls = nn.BatchNorm2d
        if complex_norm:
            norm_cls = ComplexBatchNorm2D

        nonlinearity = nn.PReLU()
        if complex_relu:
            nonlinearity = ComplexRelu()

        stages = [
            ComplexConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            )
        ]
        if layer != 0:
            stages.append(norm_cls(out_channels))
            stages.append(nonlinearity)

        self.decoder = nn.Sequential(*stages)

    def forward(self, waveform):
        """Run the stage on ``waveform``.

        Despite the parameter name, ``DCCRN.forward`` in this file passes
        feature maps concatenated with a skip connection here, not raw audio.
        """
        return self.decoder(waveform)
class DCCRN(Model):
    """Deep Complex Convolution Recurrent Network for speech enhancement.

    The model computes an STFT of the input, runs the real/imaginary
    spectrogram through a stack of ``DCCRN_ENCODER`` stages, an LSTM
    bottleneck and a mirrored stack of ``DCCRN_DECODER`` stages with skip
    connections, and uses the decoder output as a mask on the input
    spectrogram before inverting it back to a waveform.

    NOTE(review): earlier revisions of ``forward`` fed magnitude/phase into
    the encoder and combined mask and spectrogram incorrectly; checkpoints
    trained with that behaviour will need re-training.

    Args:
        stft: overrides for ``STFT_DEFAULTS`` (combined through
            ``merge_dict``; presumably caller values win - confirm against
            ``merge_dict``).
        encoder_decoder: overrides for ``ED_DEFAULTS`` (same merge).
        lstm: overrides for ``LSTM_DEFAULTS`` (same merge).
        complex_lstm: use a stack of ``ComplexLSTM`` layers when True,
            otherwise a plain ``nn.LSTM`` followed by a linear projection.
        complex_norm: forwarded to every encoder/decoder stage.
        complex_relu: forwarded to every encoder/decoder stage.
        masking_mode: ``"E"`` applies the mask in polar form (bounded
            magnitude, additive phase), ``"C"`` applies a complex
            multiplication, any other value multiplies the real and imaginary
            parts independently.
        num_channels: number of audio channels, forwarded to ``Model``.
        sampling_rate: sampling rate; replaced by ``dataset.sampling_rate``
            (with a warning) when a dataset with a different rate is given.
        lr: forwarded to ``Model``.
        dataset: optional dataset, forwarded to ``Model``.
        duration: forwarded to ``Model``; replaced by ``dataset.duration``
            when ``dataset`` is an ``EnhancerDataset``.
        loss: forwarded to ``Model``.
        metric: forwarded to ``Model``.
    """

    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        # The dataset's duration takes precedence; without a dataset the
        # caller-supplied value is kept instead of being silently dropped.
        if isinstance(dataset, EnhancerDataset):
            duration = dataset.duration

        if dataset is not None and sampling_rate != dataset.sampling_rate:
            logging.warning(
                f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
            )
            sampling_rate = dataset.sampling_rate

        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        # complex_relu changes the architecture, so it is recorded together
        # with the other structural hyperparameters.
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "complex_relu",
            "masking_mode",
        )

        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        depth = encoder_decoder["depth"]
        growth_factor = encoder_decoder["growth_factor"]
        # Real and imaginary parts are stacked on the channel axis.
        in_channels = num_channels * 2
        out_channels = encoder_decoder["initial_output_channels"]

        for layer in range(depth):
            self.encoder.append(
                DCCRN_ENCODER(
                    in_channels,
                    out_channels,
                    kernel_size=(encoder_decoder["kernel_size"], 2),
                    stride=(encoder_decoder["stride"], 1),
                    padding=(encoder_decoder["padding"], 1),
                    complex_norm=complex_norm,
                    complex_relu=complex_relu,
                )
            )
            # Decoders mirror the encoders, so each new one goes to the
            # front; the decoder built with layer == 0 therefore runs last.
            # Its input is doubled because forward() concatenates the
            # matching encoder output before every decoder.
            self.decoder.insert(
                0,
                DCCRN_DECODER(
                    out_channels + out_channels,
                    in_channels,
                    layer=layer,
                    kernel_size=(encoder_decoder["kernel_size"], 2),
                    stride=(encoder_decoder["stride"], 1),
                    padding=(encoder_decoder["padding"], 0),
                    output_padding=(encoder_decoder["output_padding"], 0),
                    complex_norm=complex_norm,
                    complex_relu=complex_relu,
                ),
            )

            in_channels = out_channels
            # Channels grow for all but the last three stages.
            if layer < depth - 3:
                out_channels *= growth_factor

        # Feature size of one half (real or imaginary) of the bottleneck:
        # half of the last encoder's channels times nfft / 2**depth.
        # NOTE(review): this matches the actual encoder output only when the
        # frequency axis halves exactly at every stage - confirm for
        # non-default nfft/kernel/stride/padding settings.
        half_features = int(
            (stft["nfft"] / 2 ** depth) * (out_channels / 2)
        )

        if self.complex_lstm:
            lstms = []
            last_layer = lstm["num_layers"] - 1
            for layer in range(lstm["num_layers"]):
                lstms.append(
                    ComplexLSTM(
                        # Only the last layer projects back to the
                        # bottleneck feature size.
                        projection_size=(
                            half_features if layer == last_layer else None
                        ),
                        input_size=(
                            half_features
                            if layer == 0
                            else lstm["hidden_size"]
                        ),
                        hidden_size=lstm["hidden_size"],
                        num_layers=1,
                    )
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            # forward() flattens all C channels (both halves) with the
            # frequency axis, hence twice the per-half feature size.
            full_features = 2 * half_features
            self.lstm = nn.LSTM(
                input_size=full_features,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )
            # nn.LSTM returns (output, state), so the projection is kept as
            # a separate module and applied in forward().
            self.lstm_projection = nn.Linear(
                lstm["hidden_size"], full_features
            )

    def forward(self, waveform):
        """Enhance ``waveform`` and return the result clamped to [-1, 1].

        Args:
            waveform: input accepted by ``ConvSTFT`` (presumably a batch of
                audio shaped (batch, channels, samples) - TODO confirm).

        Returns:
            The output of ``ConviSTFT`` on the masked spectrogram, clamped to
            the range [-1, 1].
        """
        waveform_stft = self.stft(waveform)
        # The STFT output stacks real bins followed by imaginary bins on
        # dim 1; each half has nfft // 2 + 1 entries.
        num_bins = self.stft.nfft // 2 + 1
        real = waveform_stft[:, :num_bins]
        imag = waveform_stft[:, num_bins:]

        mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
        phase_spec = torch.atan2(imag, real)
        # The complex layers operate on real/imaginary channels; the first
        # frequency bin is dropped here and zero-padded back on the mask.
        complex_spec = torch.stack([real, imag], 1)[:, :, 1:]

        encoder_outputs = []
        out = complex_spec
        for encoder in self.encoder:
            out = encoder(out)
            encoder_outputs.append(out)

        B, C, D, T = out.size()
        out = out.permute(3, 0, 1, 2)  # (T, B, C, D)
        if self.complex_lstm:
            # First half of the channels is the real part, second half the
            # imaginary part.
            lstm_real = out[:, :, : C // 2].reshape(T, B, C // 2 * D)
            lstm_imag = out[:, :, C // 2 :].reshape(T, B, C // 2 * D)
            lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
            lstm_real = lstm_real.reshape(T, B, C // 2, D)
            lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
            out = torch.cat([lstm_real, lstm_imag], 2)
        else:
            out = out.reshape(T, B, C * D)
            out, _ = self.lstm(out)
            out = self.lstm_projection(out)
            # Undo the flattening in the same (C, D) order it was done in.
            out = out.reshape(T, B, C, D)

        out = out.permute(1, 2, 3, 0)  # back to (B, C, D, T)
        for decoder in self.decoder:
            # Skip connections are consumed deepest-first.
            skip_connection = encoder_outputs.pop()
            out = complex_cat([skip_connection, out])
            out = decoder(out)
            # Drop the first frame along the last axis after every decoder.
            out = out[..., 1:]

        mask_real, mask_imag = out[:, 0], out[:, 1]
        # Restore the frequency bin removed before the encoder (zero mask).
        mask_real = F.pad(mask_real, [0, 0, 1, 0])
        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])

        if self.masking_mode == "E":
            # Polar masking: bounded magnitude gain and additive phase.
            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
            real_phase = mask_real / (mask_mag + 1e-8)
            imag_phase = mask_imag / (mask_mag + 1e-8)
            mask_phase = torch.atan2(imag_phase, real_phase)
            mask_mag = torch.tanh(mask_mag)
            est_mag = mask_mag * mag_spec
            est_phase = mask_phase + phase_spec
            # mag * (cos(theta) + i sin(theta))
            real = est_mag * torch.cos(est_phase)
            imag = est_mag * torch.sin(est_phase)
        elif self.masking_mode == "C":
            # Complex multiplication; both parts must be computed from the
            # unmasked values, hence the simultaneous assignment.
            real, imag = (
                real * mask_real - imag * mask_imag,
                real * mask_imag + imag * mask_real,
            )
        else:
            # Independent real-valued masks on each part.
            real = real * mask_real
            imag = imag * mask_imag

        spec = torch.cat([real, imag], 1)
        wav = self.istft(spec)
        wav = wav.clamp_(-1, 1)
        return wav