import logging
from typing import Any, List, Optional, Tuple, Union

from torch import nn

from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import ComplexConv2d, ComplexLSTM
from enhancer.models.complexnn.conv import ComplexConvTranspose2d
from enhancer.models.complexnn.utils import ComplexBatchNorm2D, ComplexRelu
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict


class DCCRN_ENCODER(nn.Module):
    """One DCCRN encoder stage: complex Conv2d -> batch-norm -> activation.

    Parameters
    ----------
    in_channels : input channel count (real+imaginary planes included).
    out_channel : output channel count.
    kernel_size : 2-D convolution kernel (freq, time).
    complex_norm : use ComplexBatchNorm2D instead of nn.BatchNorm2d.
    complex_relu : use ComplexRelu instead of nn.PReLU.
    stride, padding : passed through to the convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()
        # Real-valued fallbacks keep the stage usable for ablations.
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        self.encoder = nn.Sequential(
            ComplexConv2d(
                in_channels,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            batchnorm(out_channel),
            activation,
        )

    def forward(self, waveform):
        """Apply conv/norm/activation to the complex spectrogram input."""
        return self.encoder(waveform)


class DCCRN_DECODER(nn.Module):
    """One DCCRN decoder stage: complex ConvTranspose2d -> batch-norm -> activation.

    Mirrors DCCRN_ENCODER; `output_padding` compensates the encoder's
    stride-2 frequency downsampling on the transposed convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        self.decoder = nn.Sequential(
            ComplexConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            batchnorm(out_channels),
            activation,
        )

    def forward(self, waveform):
        """Apply transposed-conv/norm/activation to the encoder features."""
        return self.decoder(waveform)


class DCCRN(Model):
    """Deep Complex Convolution Recurrent Network for speech enhancement.

    Builds a symmetric complex encoder/decoder stack with skip connections
    (decoder input channels are doubled) and a complex (or real) LSTM
    bottleneck. STFT/iSTFT are realised as fixed convolutions.
    """

    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        # FIX: the dataset's duration wins when a dataset is supplied, but the
        # caller-provided `duration` is no longer silently discarded otherwise
        # (the original reset it to None unconditionally).
        duration = (
            dataset.duration
            if isinstance(dataset, EnhancerDataset)
            else duration
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        # FIX: also persist complex_relu so the model can be rebuilt
        # faithfully from a checkpoint (it was omitted).
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "complex_relu",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        num_channels *= 2  # real + imaginary planes
        hidden_size = encoder_decoder["initial_output_channels"]
        # FIX: honour the configurable growth factor (was hard-coded to 2).
        growth_factor = encoder_decoder["growth_factor"]

        for layer in range(encoder_decoder["depth"]):

            encoder_ = DCCRN_ENCODER(
                num_channels,
                hidden_size,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 1),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.encoder.append(encoder_)

            # Decoder input is doubled to accommodate the skip connection.
            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
                output_padding=(encoder_decoder["output_padding"], 0),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )

            # Decoders are built in reverse so decoder[i] mirrors encoder[-i].
            self.decoder.insert(0, decoder_)

            if layer < encoder_decoder["depth"] - 3:
                num_channels = hidden_size
                hidden_size *= growth_factor
            else:
                num_channels = hidden_size

        # FIX: use integer arithmetic — `/` produced floats, which made the
        # non-complex nn.LSTM below receive a float input_size (TypeError).
        kernel_size = hidden_size // 2
        hidden_size = stft["nfft"] // 2 ** encoder_decoder["depth"]

        if self.complex_lstm:
            lstms = []
            for layer in range(lstm["num_layers"]):

                if layer == 0:
                    input_size = hidden_size * kernel_size
                else:
                    input_size = lstm["hidden_size"]

                # Only the last layer projects back to the feature size.
                if layer == lstm["num_layers"] - 1:
                    projection_size = hidden_size * kernel_size
                else:
                    projection_size = None

                kwargs = {
                    "input_size": input_size,
                    "hidden_size": lstm["hidden_size"],
                    "num_layers": 1,
                }

                lstms.append(
                    ComplexLSTM(projection_size=projection_size, **kwargs)
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            self.lstm = nn.LSTM(
                input_size=hidden_size * kernel_size,
                # FIX: keyword was misspelled "hidden_sizs", which would raise
                # a TypeError the first time this branch was taken.
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )

    def forward(self, waveform):
        # TODO: placeholder — the enhancement pipeline (STFT -> encoder ->
        # LSTM -> decoder -> masking -> iSTFT) is not implemented yet; the
        # input is currently returned unchanged.
        return waveform