From 511d2141d4ee2229fcb52dad59966e4d581d918a Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 7 Nov 2022 10:26:51 +0530
Subject: [PATCH] DCCRN implementation

Implement DCCRN.forward: STFT -> encoder stack -> (complex) LSTM ->
decoder stack with skip connections -> spectral masking -> iSTFT.

DCCRN_DECODER gains a `layer` argument; layer 0 is built as a bare
transposed convolution, every other layer also gets normalisation and
an activation.

The real-valued LSTM fallback is an nn.LSTM followed by a separate
linear projection, because nn.LSTM returns an (output, state) tuple and
therefore cannot be chained inside nn.Sequential.
---
 enhancer/models/dccrn.py | 111 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 97 insertions(+), 14 deletions(-)

diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py
index 00d9f9d..c6ee837 100644
--- a/enhancer/models/dccrn.py
+++ b/enhancer/models/dccrn.py
@@ -1,6 +1,8 @@
 import logging
 from typing import Any, List, Optional, Tuple, Union
 
+import torch
+import torch.nn.functional as F
 from torch import nn
 
 from enhancer.data import EnhancerDataset
@@ -12,6 +14,7 @@ from enhancer.models.complexnn import (
     ComplexLSTM,
     ComplexRelu,
 )
+from enhancer.models.complexnn.utils import complex_cat
 from enhancer.utils.transforms import ConviSTFT, ConvSTFT
 from enhancer.utils.utils import merge_dict
 
@@ -54,6 +57,7 @@ class DCCRN_DECODER(nn.Module):
         in_channels: int,
         out_channels: int,
         kernel_size: Tuple[int, int],
+        layer: int = 0,
         complex_norm: bool = True,
         complex_relu: bool = True,
         stride: Tuple[int, int] = (2, 1),
@@ -64,18 +68,21 @@ class DCCRN_DECODER(nn.Module):
         batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
         activation = ComplexRelu() if complex_relu else nn.PReLU()
 
-        self.decoder = nn.Sequential(
-            ComplexConvTranspose2d(
-                in_channels,
-                out_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                padding=padding,
-                output_padding=output_padding,
-            ),
-            batchnorm(out_channels),
-            activation,
-        )
+        # The transposed convolution is always present; normalisation and
+        # the activation are appended for every layer except layer 0.
+        layers = [
+            ComplexConvTranspose2d(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                output_padding=output_padding,
+            )
+        ]
+        if layer != 0:
+            layers.extend([batchnorm(out_channels), activation])
+        self.decoder = nn.Sequential(*layers)
 
     def forward(self, waveform):
 
@@ -187,6 +194,7 @@ class DCCRN(Model):
             decoder_ = DCCRN_DECODER(
                 hidden_size + hidden_size,
                 num_channels,
+                layer=layer,
                 kernel_size=(encoder_decoder["kernel_size"], 2),
                 stride=(encoder_decoder["stride"], 1),
                 padding=(encoder_decoder["padding"], 0),
@@ -231,14 +239,89 @@ class DCCRN(Model):
                 )
             self.lstm = nn.Sequential(*lstms)
         else:
+            # nn.LSTM returns ``(output, state)``, so it cannot be chained
+            # inside nn.Sequential; forward() unpacks it and projects back.
             self.lstm = nn.LSTM(
                 input_size=hidden_size * kernel_size,
-                hidden_sizs=lstm["hidden_size"],
+                hidden_size=lstm["hidden_size"],
                 num_layers=lstm["num_layers"],
                 dropout=0.0,
                 batch_first=False,
             )
+            self.lstm_projection = nn.Linear(
+                lstm["hidden_size"], hidden_size * kernel_size
+            )
 
     def forward(self, waveform):
 
-        return waveform
+        """Enhance ``waveform`` by masking its STFT and resynthesising it."""
+        waveform_stft = self.stft(waveform)
+        real = waveform_stft[:, : self.stft.nfft // 2 + 1]
+        imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
+
+        mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
+        phase_spec = torch.atan2(imag, real)
+        # NOTE(review): the complex layers presumably expect (real, imag)
+        # channels, yet (magnitude, phase) is stacked here -- confirm.
+        complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
+
+        encoder_outputs = []
+        out = complex_spec
+        for encoder in self.encoder:
+            out = encoder(out)
+            encoder_outputs.append(out)
+
+        B, C, D, T = out.size()
+        out = out.permute(3, 0, 1, 2)
+        if self.complex_lstm:
+            lstm_real = out[:, :, : C // 2]
+            lstm_imag = out[:, :, C // 2 :]
+            lstm_real = lstm_real.reshape(T, B, C // 2 * D)
+            lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
+            lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
+            lstm_real = lstm_real.reshape(T, B, C // 2, D)
+            lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
+            out = torch.cat([lstm_real, lstm_imag], 2)
+        else:
+            out = out.reshape(T, B, C * D)
+            out, _ = self.lstm(out)
+            out = self.lstm_projection(out)
+            # Invert the (C, D) -> C * D flattening in the same axis order.
+            out = out.reshape(T, B, C, D)
+
+        out = out.permute(1, 2, 3, 0)
+        for decoder in self.decoder:
+            skip_connection = encoder_outputs.pop()
+            out = complex_cat([skip_connection, out])
+            out = decoder(out)
+            out = out[..., 1:]
+        mask_real, mask_imag = out[:, 0], out[:, 1]
+        mask_real = F.pad(mask_real, [0, 0, 1, 0])
+        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
+
+        if self.masking_mode == "E":
+            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
+            real_phase = mask_real / (mask_mag + 1e-8)
+            imag_phase = mask_imag / (mask_mag + 1e-8)
+            mask_phase = torch.atan2(imag_phase, real_phase)
+            mask_mag = torch.tanh(mask_mag)
+            est_mag = mask_mag * mag_spec
+            # Multiplying complex numbers multiplies magnitudes, adds phases.
+            est_phase = mask_phase + phase_spec
+            # Polar to rectangular: |z| * (cos(theta) + i * sin(theta)).
+            real = est_mag * torch.cos(est_phase)
+            imag = est_mag * torch.sin(est_phase)
+        elif self.masking_mode == "C":
+            # Complex product; both parts must use the unmasked real/imag.
+            real, imag = (
+                real * mask_real - imag * mask_imag,
+                real * mask_imag + imag * mask_real,
+            )
+        else:
+            real = real * mask_real
+            imag = imag * mask_imag
+
+        spec = torch.cat([real, imag], 1)
+        wav = self.istft(spec).squeeze(1)
+        wav = wav.clamp_(-1, 1)
+        return wav