import logging
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import (
    ComplexBatchNorm2D,
    ComplexConv2d,
    ComplexConvTranspose2d,
    ComplexLSTM,
    ComplexRelu,
)
from enhancer.models.complexnn.utils import complex_cat
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict


class DCCRN_ENCODER(nn.Module):
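    """DCCRN encoder block: a complex 2-D convolution followed by batch
    normalization and an activation, using the complex variants of both
    by default."""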

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        self.encoder = nn.Sequential(
            ComplexConv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            batchnorm(out_channels),
            activation,
        )

    def forward(self, waveform):
        return self.encoder(waveform)


class DCCRN_DECODER(nn.Module):
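    """DCCRN decoder block: a complex transposed 2-D convolution, followed
    by batch normalization and an activation on every block except the
    final one (layer == 0)."""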

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        layer: int = 0,
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        layers = [
            ComplexConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            )
        ]
        # Every decoder block except the last one (layer == 0) is followed
        # by normalization and activation.
        if layer != 0:
            layers += [batchnorm(out_channels), activation]
        self.decoder = nn.Sequential(*layers)

    def forward(self, waveform):
        return self.decoder(waveform)


class DCCRN(Model):
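    """Deep Complex Convolution Recurrent Network (DCCRN) for speech
    enhancement: a U-Net style complex encoder/decoder over STFT features
    with an LSTM bottleneck and mask-based reconstruction (masking modes
    "E", "C", and "R")."""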

    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }
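
    # With these defaults the STFT uses a 25 ms Hamming window (400 samples
    # at 16 kHz) with a 6.25 ms hop, zero-padded to a 512-point FFT.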

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else duration
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"Model sampling rate {sampling_rate} does not match dataset "
                    f"sampling rate {dataset.sampling_rate}; using the dataset's rate."
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        num_channels *= 2  # two spectral channels per audio channel (see forward())
        hidden_size = encoder_decoder["initial_output_channels"]
        growth_factor = encoder_decoder["growth_factor"]

        # With the defaults (depth=6, initial_output_channels=32,
        # growth_factor=2) the encoder channels grow as
        # 2 -> 32 -> 64 -> 128 -> 256 -> 256 -> 256; the width stops
        # doubling over the last three layers.
        for layer in range(encoder_decoder["depth"]):
            encoder_ = DCCRN_ENCODER(
                num_channels,
                hidden_size,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 1),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.encoder.append(encoder_)

            # Each decoder consumes the matching encoder output concatenated
            # with the previous decoder output (skip connection), hence the
            # doubled input channel count.
            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                layer=layer,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
                output_padding=(encoder_decoder["output_padding"], 0),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            # Decoders mirror the encoders, so prepend.
            self.decoder.insert(0, decoder_)

            num_channels = hidden_size
            if layer < encoder_decoder["depth"] - 3:
                hidden_size *= growth_factor

        # The variables are reused here: `kernel_size` becomes the number of
        # complex channel pairs at the bottleneck and `hidden_size` the number
        # of frequency bins left after `depth` stride-2 downsamplings.
        kernel_size = hidden_size // 2
        hidden_size = stft["nfft"] // 2 ** encoder_decoder["depth"]

        if self.complex_lstm:
            lstms = []
            for layer in range(lstm["num_layers"]):
                # The first layer consumes the flattened encoder output; the
                # last layer projects back to that size so the decoder sees
                # the same feature dimension.
                if layer == 0:
                    input_size = int(hidden_size * kernel_size)
                else:
                    input_size = lstm["hidden_size"]

                if layer == lstm["num_layers"] - 1:
                    projection_size = int(hidden_size * kernel_size)
                else:
                    projection_size = None

                kwargs = {
                    "input_size": input_size,
                    "hidden_size": lstm["hidden_size"],
                    "num_layers": 1,
                }
                lstms.append(
                    ComplexLSTM(projection_size=projection_size, **kwargs)
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            # nn.LSTM returns (output, (h_n, c_n)) and therefore cannot be
            # chained inside nn.Sequential; keep the LSTM and its output
            # projection as separate modules and apply them in forward().
            self.lstm = nn.LSTM(
                input_size=int(hidden_size * kernel_size),
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )
            self.lstm_projection = nn.Linear(
                lstm["hidden_size"], int(hidden_size * kernel_size)
            )

    def forward(self, waveform):
        """Enhance a batch of noisy waveforms.

        Accepts (batch, num_channels, num_samples) input; a 2-D batch is
        treated as single-channel. Returns the enhanced waveform, clamped
        to [-1, 1].
        """
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != self.hparams.num_channels:
            raise ValueError(
                f"Number of input channels initialized is {self.hparams.num_channels}, "
                f"but got {waveform.size(1)} channels"
            )

        waveform_stft = self.stft(waveform)
        real = waveform_stft[:, : self.stft.nfft // 2 + 1]
        imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]

        mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
        phase_spec = torch.atan2(imag, real)
        # Stack the two spectral channels and drop the DC bin.
        complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]

        encoder_outputs = []
        out = complex_spec
        for encoder in self.encoder:
            out = encoder(out)
            encoder_outputs.append(out)

        B, C, D, T = out.size()
        out = out.permute(3, 0, 1, 2)  # (time, batch, channels, freq)
        if self.complex_lstm:
            # Split into real/imaginary halves, flatten the feature axes,
            # run the complex LSTM stack, then restore the layout.
            lstm_real = out[:, :, : C // 2]
            lstm_imag = out[:, :, C // 2 :]
            lstm_real = lstm_real.reshape(T, B, C // 2 * D)
            lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
            lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
            lstm_real = lstm_real.reshape(T, B, C // 2, D)
            lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
            out = torch.cat([lstm_real, lstm_imag], 2)
        else:
            out = out.reshape(T, B, C * D)
            out, _ = self.lstm(out)
            out = self.lstm_projection(out)
            out = out.reshape(T, B, C, D)

        out = out.permute(1, 2, 3, 0)  # back to (batch, channels, freq, time)
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            out = complex_cat([skip_connection, out])
            out = decoder(out)
            out = out[..., 1:]  # drop the extra frame added by the transposed conv
        mask_real, mask_imag = out[:, 0], out[:, 1]
        # Re-insert the DC bin that was dropped before encoding.
        mask_real = F.pad(mask_real, [0, 0, 1, 0])
        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
        if self.masking_mode == "E":
            # Polar masking: bounded magnitude mask, additive phase mask.
            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
            real_phase = mask_real / (mask_mag + 1e-8)
            imag_phase = mask_imag / (mask_mag + 1e-8)
            mask_phase = torch.atan2(imag_phase, real_phase)
            mask_mag = torch.tanh(mask_mag)
            est_mag = mask_mag * mag_spec
            est_phase = phase_spec + mask_phase
            # mag * (cos(theta) + i sin(theta))
            real = est_mag * torch.cos(est_phase)
            imag = est_mag * torch.sin(est_phase)
        elif self.masking_mode == "C":
            # Complex ratio masking: complex multiplication of spectrum
            # and mask, computed simultaneously so neither side sees the
            # other's updated value.
            real, imag = (
                real * mask_real - imag * mask_imag,
                real * mask_imag + imag * mask_real,
            )
        else:
            # "R": mask the real and imaginary parts independently.
            real = real * mask_real
            imag = imag * mask_imag

        spec = torch.cat([real, imag], 1)
        wav = self.istft(spec)
        wav = wav.clamp_(-1, 1)
        return wav
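

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module. It
    # assumes the default 16 kHz configuration and a 1-second mono batch.
    model = DCCRN()
    noisy = torch.randn(4, 1, 16000)  # (batch, channels, samples)
    enhanced = model(noisy)
    print(enhanced.shape)  # enhanced waveform, clamped to [-1, 1]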