241 lines
7.2 KiB
Python
241 lines
7.2 KiB
Python
import logging
|
|
from typing import Any, List, Optional, Tuple, Union
|
|
|
|
from torch import nn
|
|
|
|
from enhancer.data import EnhancerDataset
|
|
from enhancer.models import Model
|
|
from enhancer.models.complexnn import ComplexConv2d, ComplexLSTM
|
|
from enhancer.models.complexnn.conv import ComplexConvTranspose2d
|
|
from enhancer.models.complexnn.utils import ComplexBatchNorm2D, ComplexRelu
|
|
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
|
|
from enhancer.utils.utils import merge_dict
|
|
|
|
|
|
class DCCRN_ENCODER(nn.Module):
    """Single encoder stage: complex 2-D convolution, batch norm, activation.

    The three layers are stored in ``self.encoder`` (an ``nn.Sequential``)
    in that order.

    Args:
        in_channels: Input channel count handed to ``ComplexConv2d``.
        out_channel: Output channel count handed to ``ComplexConv2d`` and
            used as the feature count of the normalisation layer.
        kernel_size: Kernel size forwarded to ``ComplexConv2d``.
        complex_norm: Use ``ComplexBatchNorm2D`` when true, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: Use ``ComplexRelu`` when true, otherwise ``nn.PReLU``.
        stride: Stride forwarded to ``ComplexConv2d``.
        padding: Padding forwarded to ``ComplexConv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()

        # Pick the normalisation class; it is instantiated after the conv.
        if complex_norm:
            norm_cls = ComplexBatchNorm2D
        else:
            norm_cls = nn.BatchNorm2d

        # The activation is instantiated first, before the conv and norm.
        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        conv = ComplexConv2d(
            in_channels,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        norm = norm_cls(out_channel)

        self.encoder = nn.Sequential(conv, norm, nonlinearity)

    def forward(self, waveform):
        """Apply conv -> norm -> activation to ``waveform`` and return it.

        NOTE(review): despite the parameter name, the 2-D layers suggest a
        4-D feature map is expected here -- confirm against the caller.
        """
        return self.encoder(waveform)
|
|
|
|
|
|
class DCCRN_DECODER(nn.Module):
    """Single decoder stage: complex transposed conv, batch norm, activation.

    The three layers are stored in ``self.decoder`` (an ``nn.Sequential``)
    in that order.

    Args:
        in_channels: Input channel count handed to
            ``ComplexConvTranspose2d``.
        out_channels: Output channel count handed to
            ``ComplexConvTranspose2d`` and used as the feature count of the
            normalisation layer.
        kernel_size: Kernel size forwarded to ``ComplexConvTranspose2d``.
        complex_norm: Use ``ComplexBatchNorm2D`` when true, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: Use ``ComplexRelu`` when true, otherwise ``nn.PReLU``.
        stride: Stride forwarded to ``ComplexConvTranspose2d``.
        padding: Padding forwarded to ``ComplexConvTranspose2d``.
        output_padding: Output padding forwarded to
            ``ComplexConvTranspose2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()

        # Pick the normalisation class; it is instantiated after the conv.
        if complex_norm:
            norm_cls = ComplexBatchNorm2D
        else:
            norm_cls = nn.BatchNorm2d

        # The activation is instantiated first, before the conv and norm.
        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        deconv = ComplexConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        norm = norm_cls(out_channels)

        self.decoder = nn.Sequential(deconv, norm, nonlinearity)

    def forward(self, waveform):
        """Apply transposed conv -> norm -> activation and return the result.

        NOTE(review): despite the parameter name, the 2-D layers suggest a
        4-D feature map is expected here -- confirm against the caller.
        """
        return self.decoder(waveform)
|
|
|
|
|
|
class DCCRN(Model):
    """DCCRN model: STFT front end, encoder/decoder stacks and an LSTM.

    The constructor builds a ``ConvSTFT`` / ``ConviSTFT`` pair, a stack of
    ``DCCRN_ENCODER`` stages with a mirrored stack of ``DCCRN_DECODER``
    stages, and either a chain of ``ComplexLSTM`` layers or a plain
    ``nn.LSTM``.

    NOTE(review): ``forward`` currently returns its input unchanged; none of
    the sub-modules built here are used by it yet.

    Args:
        stft: Overrides for ``STFT_DEFAULTS`` (``window_len``, ``hop_size``,
            ``nfft``, ``window``).
        encoder_decoder: Overrides for ``ED_DEFAULTS``.
        lstm: Overrides for ``LSTM_DEFAULTS`` (``num_layers``,
            ``hidden_size``).
        complex_lstm: Build ``ComplexLSTM`` layers when true, otherwise a
            single ``nn.LSTM``.
        complex_norm: Forwarded to every encoder/decoder stage.
        complex_relu: Forwarded to every encoder/decoder stage.
        masking_mode: Stored on the instance and in the hyperparameters; not
            otherwise used in this class.
        num_channels: Passed to the base ``Model``; twice this value is used
            as the input channel count of the first encoder stage.
        sampling_rate: Passed to the base ``Model``. When ``dataset`` is
            given and its sampling rate differs, a warning is logged and the
            dataset's rate is used instead.
        lr: Passed to the base ``Model``.
        dataset: Optional dataset; supplies ``duration`` and
            ``sampling_rate``.
        duration: NOTE(review): this argument is currently ignored -- the
            value is always taken from ``dataset`` (or set to ``None``).
        loss: Passed to the base ``Model``.
        metric: Passed to the base ``Model``.
    """

    # Default STFT settings, combined with the ``stft`` argument.
    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    # Default encoder/decoder settings, combined with ``encoder_decoder``.
    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    # Default LSTM settings, combined with the ``lstm`` argument.
    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        # NOTE(review): the caller-supplied ``duration`` is discarded here;
        # kept as-is because changing it would alter existing behaviour.
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        # Combine the class defaults with the caller's overrides before the
        # hyperparameters are recorded, so the merged dicts are what is saved.
        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        # ``complex_relu`` is recorded too: it decides which activation
        # modules exist, so it is needed to rebuild the same architecture.
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "complex_relu",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        depth = encoder_decoder["depth"]
        # Presumably doubled to hold real and imaginary parts -- TODO confirm.
        num_channels *= 2
        hidden_size = encoder_decoder["initial_output_channels"]
        # Honour the configured growth factor instead of a hard-coded 2.
        growth_factor = encoder_decoder["growth_factor"]

        for layer in range(depth):
            encoder_ = DCCRN_ENCODER(
                num_channels,
                hidden_size,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 1),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.encoder.append(encoder_)

            # The decoder input is twice the matching encoder's output width
            # (presumably for a skip-connection concatenation -- TODO confirm).
            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
                output_padding=(encoder_decoder["output_padding"], 0),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            # Decoders are stored in reverse order of the encoders.
            self.decoder.insert(0, decoder_)

            # The next stage consumes this stage's output; the width stops
            # growing for the last three stages.
            num_channels = hidden_size
            if layer < depth - 3:
                hidden_size *= growth_factor

        # Feature size fed to (and projected back from) the LSTM. Both
        # factors come from true divisions, so cast once to int here and
        # reuse the value in both LSTM branches.
        kernel_size = hidden_size / 2
        hidden_size = stft["nfft"] / 2 ** depth
        lstm_feature_size = int(hidden_size * kernel_size)

        if self.complex_lstm:
            last_layer = lstm["num_layers"] - 1
            lstms = []
            for layer in range(lstm["num_layers"]):
                # Only the first layer sees the encoder features, and only
                # the last layer projects back to that size.
                lstms.append(
                    ComplexLSTM(
                        projection_size=(
                            lstm_feature_size if layer == last_layer else None
                        ),
                        input_size=(
                            lstm_feature_size
                            if layer == 0
                            else lstm["hidden_size"]
                        ),
                        hidden_size=lstm["hidden_size"],
                        num_layers=1,
                    )
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            self.lstm = nn.LSTM(
                input_size=lstm_feature_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )

    def forward(self, waveform):
        """Return ``waveform`` unchanged.

        NOTE(review): placeholder -- the STFT, encoder, LSTM and decoder
        built in ``__init__`` are not applied here yet.
        """
        return waveform
|