import logging
from typing import Any, List, Optional, Tuple, Union

from torch import nn

from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import (
    ComplexBatchNorm2D,
    ComplexConv2d,
    ComplexConvTranspose2d,
    ComplexLSTM,
    ComplexRelu,
)
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict


class DCCRN_ENCODER(nn.Module):
    """Encoder block: complex convolution -> (complex) batch norm -> (complex) activation."""

    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()
        self.encoder = nn.Sequential(
            ComplexConv2d(
                in_channels,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            batchnorm(out_channel),
            activation,
        )

    def forward(self, waveform):
        return self.encoder(waveform)


class DCCRN_DECODER(nn.Module):
    """Decoder block: complex transposed convolution -> (complex) batch norm -> (complex) activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()
        self.decoder = nn.Sequential(
            ComplexConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            batchnorm(out_channels),
            activation,
        )

    def forward(self, waveform):
        return self.decoder(waveform)


class DCCRN(Model):
    """Deep Complex Convolution Recurrent Network (DCCRN) built from complex
    encoder/decoder blocks with a (complex) LSTM bottleneck."""

    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # real and imaginary parts are stacked along the channel axis
        num_channels *= 2
        hidden_size = encoder_decoder["initial_output_channels"]
        growth_factor = encoder_decoder["growth_factor"]

        for layer in range(encoder_decoder["depth"]):

            encoder_ = DCCRN_ENCODER(
                num_channels,
                hidden_size,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 1),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.encoder.append(encoder_)

            # decoder input channels are doubled to account for the encoder skip connection
            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
                output_padding=(encoder_decoder["output_padding"], 0),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.decoder.insert(0, decoder_)

            # channel count grows until the last three layers, then stays constant
            if layer < encoder_decoder["depth"] - 3:
                num_channels = hidden_size
                hidden_size *= growth_factor
            else:
                num_channels = hidden_size

        # bottleneck dimensions: channels after the last encoder layer and the
        # remaining frequency bins determine the LSTM input size
        kernel_size = hidden_size / 2
        hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])

        if self.complex_lstm:
            lstms = []
            for layer in range(lstm["num_layers"]):

                if layer == 0:
                    input_size = int(hidden_size * kernel_size)
                else:
                    input_size = lstm["hidden_size"]

                if layer == lstm["num_layers"] - 1:
                    projection_size = int(hidden_size * kernel_size)
                else:
                    projection_size = None

                kwargs = {
                    "input_size": input_size,
                    "hidden_size": lstm["hidden_size"],
                    "num_layers": 1,
                }
                lstms.append(
                    ComplexLSTM(projection_size=projection_size, **kwargs)
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            self.lstm = nn.LSTM(
                input_size=int(hidden_size * kernel_size),
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )

    def forward(self, waveform):
        return waveform