From 485a74fc4e35ddab68c15333d19c2b0dab56ba70 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 26 Oct 2022 09:36:28 +0530 Subject: [PATCH 01/32] convt stft --- enhancer/utils/transforms.py | 63 ++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 enhancer/utils/transforms.py diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py new file mode 100644 index 0000000..bf481e8 --- /dev/null +++ b/enhancer/utils/transforms.py @@ -0,0 +1,63 @@ +from typing import Optional + +import numpy as np +import torch +from scipy.signal import get_window +from torch import nn + + +class ConvFFT(nn.Module): + def __init__( + self, + window_len: int, + nfft: Optional[int] = None, + window: str = "hamming", + ): + self.window_len = window_len + self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) + self.window = get_window(window, window_len) + + @property + def init_kernel(self): + + fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] + real, imag = np.real(fourier_basis), np.imag(fourier_basis) + kernel = np.concatenate([real, imag], 1).T + kernel *= self.window + return torch.from_numpy(kernel) + + +class ConvSTFT(ConvFFT): + def __init__( + self, + window_len: int, + hop_size: Optional[int] = None, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super(self, ConvSTFT).__init__( + window_len=window_len, nfft=nfft, window=window + ) + self.hop_size = hop_size if hop_size else window_len // 2 + self.register_buffer("weight", self.init_kernel) + + def forward(self, input): + pass + + +class ConviSTFT(ConvFFT): + def __init__( + self, + window_len: int, + hop_size: Optional[int] = None, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super(self, ConvSTFT).__init__( + window_len=window_len, nfft=nfft, window=window + ) + self.hop_size = hop_size if hop_size else window_len // 2 + self.register_buffer("weight", self.init_kernel) + + def forward(self, input): + pass From 23da02d47d890cba6a36f02720fb0ce185b7f203 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 26 Oct 2022 09:36:55 +0530 Subject: [PATCH 02/32] dccrn --- enhancer/models/complexnn/conv.py | 0 enhancer/models/complexnn/rnn.py | 0 enhancer/models/dccrn.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 enhancer/models/complexnn/conv.py create mode 100644 enhancer/models/complexnn/rnn.py create mode 100644 enhancer/models/dccrn.py diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py new file mode 100644 index 0000000..e69de29 diff --git a/enhancer/models/complexnn/rnn.py b/enhancer/models/complexnn/rnn.py new file mode 100644 index 0000000..e69de29 diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py new file mode 100644 index 0000000..e69de29 From 085a85d9ae84e6e5ae2d7d751899e26c2d08d3f3 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 27 Oct 2022 11:32:50 +0530 Subject: [PATCH 03/32] fourier transforms using cnn --- enhancer/utils/transforms.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py index bf481e8..bbbae90 100644 --- a/enhancer/utils/transforms.py +++ b/enhancer/utils/transforms.py @@ -2,6 +2,7 @@ from typing import Optional import numpy as np import torch +import torch.nn.functional as F from scipy.signal import get_window from torch import nn @@ -42,7 +43,22 @@ class ConvSTFT(ConvFFT): self.register_buffer("weight", self.init_kernel) def forward(self, input): - pass + + if input.dim() < 2: + raise ValueError( + f"Expected signal with shape 2 or 3 got {input.dim()}" + ) + elif input.dim() == 2: + input = input.unsqueeze(1) + else: + pass + input = F.pad( + input, + (self.window_len - self.hop_size, self.window_len - self.hop_size), + ) + output = F.conv1d(input, self.weight, stride=self.hop_size) + + return output class ConviSTFT(ConvFFT): @@ -58,6 +74,11 @@ class ConviSTFT(ConvFFT): ) self.hop_size = hop_size if hop_size else window_len // 2 self.register_buffer("weight", self.init_kernel) + self.register_buffer("enframe", np.eye(window_len).unsqueeze(1)) - def forward(self, input): - pass + def forward(self, input, phase=None): + + if phase is not None: + real = input * torch.cos(phase) + imag = input * torch.sin(phase) + input = torch.cat([real, imag], 1) From c18a85b5c87976c16cce7ca44b709bd60497221c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 29 Oct 2022 11:34:51 +0530 Subject: [PATCH 04/32] stft --- enhancer/utils/transforms.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py index bbbae90..acfc323 100644 --- a/enhancer/utils/transforms.py +++ b/enhancer/utils/transforms.py @@ -14,9 +14,12 @@ class ConvFFT(nn.Module): nfft: Optional[int] = None, window: str = "hamming", ): + super().__init__() self.window_len = window_len self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) - self.window = get_window(window, window_len) + self.window = torch.from_numpy( + get_window(window, window_len).astype("float32") + ) @property def init_kernel(self): @@ -24,8 +27,9 @@ class ConvFFT(nn.Module): fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] real, imag = np.real(fourier_basis), np.imag(fourier_basis) kernel = np.concatenate([real, imag], 1).T + kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1) kernel *= self.window - return torch.from_numpy(kernel) + return kernel class ConvSTFT(ConvFFT): @@ -36,9 +40,7 @@ class ConvSTFT(ConvFFT): nfft: Optional[int] = None, window: str = "hamming", ): - super(self, ConvSTFT).__init__( - window_len=window_len, nfft=nfft, window=window - ) + super().__init__(window_len=window_len, nfft=nfft, window=window) self.hop_size = hop_size if hop_size else window_len // 2 self.register_buffer("weight", self.init_kernel) @@ -69,12 +71,10 @@ class ConviSTFT(ConvFFT): nfft: Optional[int] = None, window: str = "hamming", ): - super(self, ConvSTFT).__init__( - window_len=window_len, nfft=nfft, window=window - ) + super().__init__(window_len=window_len, nfft=nfft, window=window) self.hop_size = hop_size if hop_size else window_len // 2 self.register_buffer("weight", self.init_kernel) - self.register_buffer("enframe", np.eye(window_len).unsqueeze(1)) + self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1)) def forward(self, input, phase=None): @@ -82,3 +82,10 @@ class ConviSTFT(ConvFFT): real = input * torch.cos(phase) imag = input * torch.sin(phase) input = torch.cat([real, imag], 1) + out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) + coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 + coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) + out = out / coeff + pad = self.window_len - self.hop_size + out = out[..., pad:-pad] + return out From cf1e5c07a9a4ea810e8f8214c79ff93d9a50a3a1 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 29 Oct 2022 11:35:35 +0530 Subject: [PATCH 05/32] test transforms --- tests/transforms_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 tests/transforms_test.py diff --git a/tests/transforms_test.py b/tests/transforms_test.py new file mode 100644 index 0000000..3053b09 --- /dev/null +++ b/tests/transforms_test.py @@ -0,0 +1,14 @@ +import torch + +from enhancer.utils.transforms import ConviSTFT, ConvSTFT + + +def test_stft_istft(): + sample_input = torch.rand(1, 1, 16000) + stft = ConvSTFT(window_len=400, hop_size=100, nfft=512) + istft = ConviSTFT(window_len=400, hop_size=100, nfft=512) + + with torch.no_grad(): + spectrogram = stft(sample_input) + waveform = istft(spectrogram) + assert sample_input.shape == waveform.shape From 6f6e7f7ad85ab92d9989404957624969a1c3e8f7 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 29 Oct 2022 13:20:04 +0530 Subject: [PATCH 06/32] init --- enhancer/models/complexnn/__init__.py | 1 + enhancer/models/complexnn/conv.py | 77 +++++++++++++++++++++++++++ tests/models/complexnn_test.py | 13 +++++ 3 files changed, 91 insertions(+) create mode 100644 enhancer/models/complexnn/__init__.py create mode 100644 tests/models/complexnn_test.py diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py new file mode 100644 index 0000000..68f376e --- /dev/null +++ b/enhancer/models/complexnn/__init__.py @@ -0,0 +1 @@ +# from enhancer.models.complexnn.conv import ComplexConv2d diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py index e69de29..eddcce3 100644 --- a/enhancer/models/complexnn/conv.py +++ b/enhancer/models/complexnn/conv.py @@ -0,0 +1,77 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def init_weights(nnet): + nn.init.xavier_normal_(nnet.weight.data) + nn.init.constant(nnet.bias, 0.0) + return nnet + + +class ComplexConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + groups: int = 1, + dilation: int = 1, + ): + """ + Complex Conv2d (non-causal) + """ + super().__init__() + self.in_channels = in_channels // 2 + self.out_channels = out_channels // 2 + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.groups = groups + self.dilation = dilation + + self.real_conv = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=(self.padding[0], 0), + groups=self.groups, + dilation=self.dilation, + ) + self.imag_conv = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=(self.padding[0], 0), + groups=self.groups, + dilation=self.dilation, + ) + self.imag_conv = init_weights(self.imag_conv) + self.real_conv = init_weights(self.real_conv) + + def forward(self, input): + """ + forward + complex axis should be always 1 dim + """ + input = F.pad(input, [self.padding[1], self.padding[1], 0, 0]) + + real, imag = torch.chunk(input, 2, 1) + real_real = self.real_conv(real) + real_imag = self.imag_conv(real) + + imag_imag = self.imag_conv(imag) + imag_real = self.real_conv(imag) + + real = real_real - imag_imag + imag = real_imag - imag_real + + out = torch.cat([real, imag], 1) + + return out diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py new file mode 100644 index 0000000..9ca5811 --- /dev/null +++ b/tests/models/complexnn_test.py @@ -0,0 +1,13 @@ +import torch + +from enhancer.models.complexnn.conv import ComplexConv2d + + +def test_complexconv2d(): + sample_input = torch.rand(1, 2, 256, 13) + conv = ComplexConv2d( + 2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1) + ) + with torch.no_grad(): + out = conv(sample_input) + assert out.shape == torch.Size([1, 32, 128, 14]) From 26cccc67721803c2b3fd0b0f988d76fce3f6278a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 31 Oct 2022 11:43:32 +0530 Subject: [PATCH 07/32] complex tranposed conv --- enhancer/models/complexnn/conv.py | 64 ++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py index eddcce3..55acc07 100644 --- a/enhancer/models/complexnn/conv.py +++ b/enhancer/models/complexnn/conv.py @@ -7,7 +7,7 @@ from torch import nn def init_weights(nnet): nn.init.xavier_normal_(nnet.weight.data) - nn.init.constant(nnet.bias, 0.0) + nn.init.constant_(nnet.bias, 0.0) return nnet @@ -57,7 +57,6 @@ class ComplexConv2d(nn.Module): def forward(self, input): """ - forward complex axis should be always 1 dim """ input = F.pad(input, [self.padding[1], self.padding[1], 0, 0]) @@ -75,3 +74,64 @@ class ComplexConv2d(nn.Module): out = torch.cat([real, imag], 1) return out + + +class ComplexConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + output_padding: Tuple[int, int] = (0, 0), + groups: int = 1, + ): + super().__init__() + self.in_channels = in_channels // 2 + self.out_channels = out_channels // 2 + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.groups = groups + self.output_padding = output_padding + + self.real_conv = nn.ConvTranspose2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + groups=self.groups, + ) + + self.imag_conv = nn.ConvTranspose2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + groups=self.groups, + ) + + init_weights(self.real_conv) + init_weights(self.imag_conv) + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + + real_real = self.real_conv(real) + real_imag = self.imag_conv(real) + + imag_imag = self.imag_conv(imag) + imag_real = self.real_conv(imag) + + real = real_real - imag_imag + imag = real_imag - imag_real + + out = torch.cat([real, imag], 1) + + return out From 7abd266ab21a47d7757385938f5abad993975992 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 31 Oct 2022 11:43:50 +0530 Subject: [PATCH 08/32] test complexnn --- tests/models/complexnn_test.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 9ca5811..53ffba2 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,6 +1,6 @@ import torch -from enhancer.models.complexnn.conv import ComplexConv2d +from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d def test_complexconv2d(): @@ -11,3 +11,19 @@ def test_complexconv2d(): with torch.no_grad(): out = conv(sample_input) assert out.shape == torch.Size([1, 32, 128, 14]) + + +def test_complexconvtranspose2d(): + sample_input = torch.rand(1, 512, 4, 13) + conv = ComplexConvTranspose2d( + 256 * 2, + 128 * 2, + kernel_size=(5, 2), + stride=(2, 1), + padding=(2, 0), + output_padding=(1, 0), + ) + with torch.no_grad(): + out = conv(sample_input) + + assert out.shape == torch.Size([1, 256, 8, 14]) From 0b50a573e83eca67773778dda5f743e4f4652287 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 1 Nov 2022 10:35:30 +0530 Subject: [PATCH 09/32] complex lstm --- enhancer/models/complexnn/rnn.py | 66 ++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/enhancer/models/complexnn/rnn.py b/enhancer/models/complexnn/rnn.py index e69de29..7d19425 100644 --- a/enhancer/models/complexnn/rnn.py +++ b/enhancer/models/complexnn/rnn.py @@ -0,0 +1,66 @@ +from typing import List, Optional + +import torch +from torch import nn + + +class ComplexLSTM(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + projection_size: Optional[int] = None, + bidirectional: bool = False, + ): + super().__init__() + self.input_size = input_size // 2 + self.hidden_size = hidden_size // 2 + self.num_layers = num_layers + + self.real_lstm = nn.LSTM( + self.input_size, + self.hidden_size, + self.num_layers, + bidirectional=bidirectional, + batch_first=False, + ) + self.imag_lstm = nn.LSTM( + self.input_size, + self.hidden_size, + self.num_layers, + bidirectional=bidirectional, + batch_first=False, + ) + + bidirectional = 2 if bidirectional else 1 + if projection_size is not None: + self.projection_size = projection_size // 2 + self.real_linear = nn.Linear( + self.hidden_size * bidirectional, self.projection_size + ) + self.imag_linear = nn.Linear( + self.hidden_size * bidirectional, self.projection_size + ) + + def forward(self, input): + + if isinstance(input, List): + real, imag = input + else: + real, imag = torch.chunk(input, 2, 1) + + real_real = self.real_lstm(real)[0] + real_imag = self.imag_lstm(real)[0] + + imag_imag = self.imag_lstm(imag)[0] + imag_real = self.real_lstm(imag)[0] + + real = real_real - imag_imag + imag = imag_real + real_imag + + if self.projection_size is not None: + real = self.real_linear(real) + imag = self.imag_linear(imag) + + return [real, imag] From b1144e7b818822c57e2bc083e6832741f91531f8 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 1 Nov 2022 10:35:49 +0530 Subject: [PATCH 10/32] tests complexnn --- tests/models/complexnn_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 53ffba2..74c2baa 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,6 +1,7 @@ import torch from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d +from enhancer.models.complexnn.rnn import ComplexLSTM def test_complexconv2d(): @@ -27,3 +28,13 @@ def test_complexconvtranspose2d(): out = conv(sample_input) assert out.shape == torch.Size([1, 256, 8, 14]) + + +def test_complexlstm(): + sample_input = torch.rand(13, 2, 128) + lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2) + with torch.no_grad(): + out = lstm(sample_input) + + assert out[0].shape == torch.Size([13, 1, 512]) + assert out[1].shape == torch.Size([13, 1, 512]) From e932dc6c75b73e3a3f25221ac78cbdf9d0dd8862 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 3 Nov 2022 11:37:58 +0530 Subject: [PATCH 11/32] batchnorm --- enhancer/models/complexnn/norm.py | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 enhancer/models/complexnn/norm.py diff --git a/enhancer/models/complexnn/norm.py b/enhancer/models/complexnn/norm.py new file mode 100644 index 0000000..eec2130 --- /dev/null +++ b/enhancer/models/complexnn/norm.py @@ -0,0 +1,72 @@ +import torch +from torch import nn + + +class ComplexBatchNorm(nn.Module): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: bool = True, + affine: bool = True, + track_running_stats: bool = True, + ): + self.num_features = num_features // 2 + self.affine = affine + self.momentum = momentum + self.track_running_stats = track_running_stats + + if self.affine: + values = torch.Tensor(self.num_features) + self.Wrr = nn.parameter.Parameter(values) + self.Wri = nn.parameter.Parameter(values) + self.Wii = nn.parameter.Parameter(values) + self.Br = nn.parameter.Parameter(values) + self.Bi = nn.parameter.Parameter(values) + else: + self.register_parameter("Wrr", None) + self.register_parameter("Wri", None) + self.register_parameter("Wii", None) + self.register_parameter("Br", None) + self.register_parameter("Bi", None) + + if self.track_running_stats: + values = torch.Tensor(self.num_features) + self.register_buffer("Mean_real", values) + self.register_buffer("Mean_imag", values) + self.register_buffer("Var_rr", values) + self.register_buffer("Var_ri", values) + self.register_buffer("Var_ii", values) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("Mean_real", None) + self.register_parameter("Mean_imag", None) + self.register_parameter("Var_rr", None) + self.register_parameter("Var_ri", None) + self.register_parameter("Var_ii", None) + self.register_parameter("num_batches_tracked", None) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.Wrr.data.fill_(1) + self.Wii.data.fill(1) + self.Wri.data.uniform_(-0.9, 0.9) + self.Br.data.fill_(0) + self.Bi.data.fill_(0) + self.reset_running_stats() + + def reset_running_stats(self): + if self.track_running_stats: + self.Mean_real.zero_() + self.Mean_imag.zero_() + self.Var_rr.fill_(1) + self.Var_ri.zero_() + self.Var_ii.fill_(1) + self.num_batches_tracked.zero_() + + def forward(self, input): + pass From da1b986d311579f60445e599d29cdf42306815d8 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 3 Nov 2022 16:05:55 +0530 Subject: [PATCH 12/32] complex batchnorm 2d --- enhancer/models/complexnn/norm.py | 95 ++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/enhancer/models/complexnn/norm.py b/enhancer/models/complexnn/norm.py index eec2130..5dd0104 100644 --- a/enhancer/models/complexnn/norm.py +++ b/enhancer/models/complexnn/norm.py @@ -2,7 +2,7 @@ import torch from torch import nn -class ComplexBatchNorm(nn.Module): +class ComplexBatchNorm2D(nn.Module): def __init__( self, num_features: int, @@ -11,10 +11,18 @@ class ComplexBatchNorm(nn.Module): affine: bool = True, track_running_stats: bool = True, ): + """ + Complex batch normalization 2D + https://arxiv.org/abs/1705.09792 + + + """ + super().__init__() self.num_features = num_features // 2 self.affine = affine self.momentum = momentum self.track_running_stats = track_running_stats + self.eps = eps if self.affine: values = torch.Tensor(self.num_features) @@ -53,7 +61,7 @@ class ComplexBatchNorm(nn.Module): def reset_parameters(self): if self.affine: self.Wrr.data.fill_(1) - self.Wii.data.fill(1) + self.Wii.data.fill_(1) self.Wri.data.uniform_(-0.9, 0.9) self.Br.data.fill_(0) self.Bi.data.fill_(0) @@ -69,4 +77,85 @@ class ComplexBatchNorm(nn.Module): self.num_batches_tracked.zero_() def forward(self, input): - pass + + real, imag = torch.chunk(input, 2, 1) + exp_avg_factor = 0.0 + + training = self.training and self.track_running_stats + if training: + self.num_batches_tracked += 1 + if self.momentum is None: + exp_avg_factor = 1 / self.num_batches_tracked + else: + exp_avg_factor = self.momentum + + redux = [i for i in reversed(range(real.dim())) if i != 1] + vdim = [1] * real.dim() + vdim[1] = real.size(1) + + if training: + batch_mean_real, batch_mean_imag = real, imag + for dim in redux: + batch_mean_real = batch_mean_real.mean(dim, keepdim=True) + batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True) + if self.track_running_stats: + self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor) + self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor) + + else: + batch_mean_real = self.Mean_real.view(vdim) + batch_mean_imag = self.Mean_imag.view(vdim) + + real -= batch_mean_real + imag -= batch_mean_imag + + if training: + batch_var_rr = real * real + batch_var_ri = real * imag + batch_var_ii = imag * imag + for dim in redux: + batch_var_rr = batch_var_rr.mean(dim, keepdim=True) + batch_var_ri = batch_var_ri.mean(dim, keepdim=True) + batch_var_ii = batch_var_ii.mean(dim, keepdim=True) + if self.track_running_stats: + self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor) + self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor) + self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor) + + batch_var_rr += self.eps + batch_var_ii += self.eps + + # Covariance matrics + # | batch_var_rr batch_var_ri | + # | batch_var_ir batch_var_ii | here batch_var_ir == batch_var_ri + # Inverse square root of cov matrix by combining below two formulas + # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix + # https://mathworld.wolfram.com/MatrixInverse.html + + tau = batch_var_rr + batch_var_ii + s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri + t = (tau + 2 * s).sqrt() + + rst = 1 / (s * t) + Urr = (batch_var_ii + s) * rst + Uri = -batch_var_ri * rst + Uii = (batch_var_rr + s) * rst + + if self.affine: + Wrr, Wri, Wii = ( + self.Wrr.view(vdim), + self.Wri.view(vdim), + self.Wii.view(vdim), + ) + Zrr = (Wrr * Urr) + (Wri * Uri) + Zri = (Wrr * Uri) + (Wri * Uii) + Zir = (Wii * Uri) + (Wri * Urr) + Zii = (Wri * Uri) + (Wii * Uii) + else: + Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii + + yr = (Zrr * real) + (Zri * imag) + yi = (Zir * real) + (Zii * imag) + + outputs = torch.cat([yr, yi], 1) + return outputs From d3e052c5f36d441393bf8df7294e15cc5545712c Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 3 Nov 2022 16:06:14 +0530 Subject: [PATCH 13/32] complex batchnorm 2d test --- tests/models/complexnn_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 74c2baa..8c18ed5 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,6 +1,7 @@ import torch from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d +from enhancer.models.complexnn.norm import ComplexBatchNorm2D from enhancer.models.complexnn.rnn import ComplexLSTM @@ -38,3 +39,12 @@ def test_complexlstm(): assert out[0].shape == torch.Size([13, 1, 512]) assert out[1].shape == torch.Size([13, 1, 512]) + + +def test_complexbatchnorm2d(): + sample_input = torch.rand(1, 64, 64, 14) + batchnorm = ComplexBatchNorm2D(num_features=64) + with torch.no_grad(): + out = batchnorm(sample_input) + + assert out.size() == sample_input.size() From 981763207ae7694d1298de2a7443a512329eded4 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:35:57 +0530 Subject: [PATCH 14/32] init dccrn --- enhancer/models/dccrn.py | 240 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 240 insertions(+) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py index e69de29..38c8145 100644 --- a/enhancer/models/dccrn.py +++ b/enhancer/models/dccrn.py @@ -0,0 +1,240 @@ +import logging +from typing import Any, List, Optional, Tuple, Union + +from torch import nn + +from enhancer.data import EnhancerDataset +from enhancer.models import Model +from enhancer.models.complexnn import ComplexConv2d, ComplexLSTM +from enhancer.models.complexnn.conv import ComplexConvTranspose2d +from enhancer.models.complexnn.utils import ComplexBatchNorm2D, ComplexRelu +from enhancer.utils.transforms import ConviSTFT, ConvSTFT +from enhancer.utils.utils import merge_dict + + +class DCCRN_ENCODER(nn.Module): + def __init__( + self, + in_channels: int, + out_channel: int, + kernel_size: Tuple[int, int], + complex_norm: bool = True, + complex_relu: bool = True, + stride: Tuple[int, int] = (2, 1), + padding: Tuple[int, int] = (2, 1), + ): + super().__init__() + batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d + activation = ComplexRelu() if complex_relu else nn.PReLU() + + self.encoder = nn.Sequential( + ComplexConv2d( + in_channels, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + batchnorm(out_channel), + activation, + ) + + def forward(self, waveform): + + return self.encoder(waveform) + + +class DCCRN_DECODER(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + complex_norm: bool = True, + complex_relu: bool = True, + stride: Tuple[int, int] = (2, 1), + padding: Tuple[int, int] = (2, 0), + output_padding: Tuple[int, int] = (1, 0), + ): + super().__init__() + batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d + activation = ComplexRelu() if complex_relu else nn.PReLU() + + self.decoder = nn.Sequential( + ComplexConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ), + batchnorm(out_channels), + activation, + ) + + def forward(self, waveform): + + return self.decoder(waveform) + + +class DCCRN(Model): + + STFT_DEFAULTS = { + "window_len": 400, + "hop_size": 100, + "nfft": 512, + "window": "hamming", + } + + ED_DEFAULTS = { + "initial_output_channels": 32, + "depth": 6, + "kernel_size": 5, + "growth_factor": 2, + "stride": 2, + "padding": 2, + "output_padding": 1, + } + + LSTM_DEFAULTS = { + "num_layers": 2, + "hidden_size": 256, + } + + def __init__( + self, + stft: Optional[dict] = None, + encoder_decoder: Optional[dict] = None, + lstm: Optional[dict] = None, + complex_lstm: bool = True, + complex_norm: bool = True, + complex_relu: bool = True, + masking_mode: str = "E", + num_channels: int = 1, + sampling_rate=16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, + loss: Union[str, List, Any] = "mse", + metric: Union[str, List] = "mse", + ): + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) + if dataset is not None: + if sampling_rate != dataset.sampling_rate: + logging.warning( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) + sampling_rate = dataset.sampling_rate + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, + ) + + encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder) + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + stft = merge_dict(self.STFT_DEFAULTS, stft) + self.save_hyperparameters( + "encoder_decoder", + "lstm", + "stft", + "complex_lstm", + "complex_norm", + "masking_mode", + ) + self.complex_lstm = complex_lstm + self.complex_norm = complex_norm + self.masking_mode = masking_mode + + self.stft = ConvSTFT( + stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] + ) + self.istft = ConviSTFT( + stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] + ) + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + num_channels *= 2 + hidden_size = encoder_decoder["initial_output_channels"] + growth_factor = 2 + + for layer in range(encoder_decoder["depth"]): + + encoder_ = DCCRN_ENCODER( + num_channels, + hidden_size, + kernel_size=(encoder_decoder["kernel_size"], 2), + stride=(encoder_decoder["stride"], 1), + padding=(encoder_decoder["padding"], 1), + complex_norm=complex_norm, + complex_relu=complex_relu, + ) + self.encoder.append(encoder_) + + decoder_ = DCCRN_DECODER( + hidden_size + hidden_size, + num_channels, + kernel_size=(encoder_decoder["kernel_size"], 2), + stride=(encoder_decoder["stride"], 1), + padding=(encoder_decoder["padding"], 0), + output_padding=(encoder_decoder["output_padding"], 0), + complex_norm=complex_norm, + complex_relu=complex_relu, + ) + + self.decoder.insert(0, decoder_) + + if layer < encoder_decoder["depth"] - 3: + num_channels = hidden_size + hidden_size *= growth_factor + else: + num_channels = hidden_size + + kernel_size = hidden_size / 2 + hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"]) + + if self.complex_lstm: + lstms = [] + for layer in range(lstm["num_layers"]): + + if layer == 0: + input_size = int(hidden_size * kernel_size) + else: + input_size = lstm["hidden_size"] + + if layer == lstm["num_layers"] - 1: + projection_size = int(hidden_size * kernel_size) + else: + projection_size = None + + kwargs = { + "input_size": input_size, + "hidden_size": lstm["hidden_size"], + "num_layers": 1, + } + + lstms.append( + ComplexLSTM(projection_size=projection_size, **kwargs) + ) + self.lstm = nn.Sequential(*lstms) + else: + self.lstm = nn.LSTM( + input_size=hidden_size * kernel_size, + hidden_sizs=lstm["hidden_size"], + num_layers=lstm["num_layers"], + dropout=0.0, + batch_first=False, + ) + + def forward(self, waveform): + + return waveform From b98599f21e0e940246f041395b6cd2fe6f40e451 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:36:27 +0530 Subject: [PATCH 15/32] rename module --- .../models/complexnn/{norm.py => utils.py} | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) rename enhancer/models/complexnn/{norm.py => utils.py} (91%) diff --git a/enhancer/models/complexnn/norm.py b/enhancer/models/complexnn/utils.py similarity index 91% rename from enhancer/models/complexnn/norm.py rename to enhancer/models/complexnn/utils.py index 5dd0104..d5de558 100644 --- a/enhancer/models/complexnn/norm.py +++ b/enhancer/models/complexnn/utils.py @@ -76,6 +76,11 @@ class ComplexBatchNorm2D(nn.Module): self.Var_ii.fill_(1) self.num_batches_tracked.zero_() + def extra_repr(self): + return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format( + **self.__dict__ + ) + def forward(self, input): real, imag = torch.chunk(input, 2, 1) @@ -159,3 +164,17 @@ class ComplexBatchNorm2D(nn.Module): outputs = torch.cat([yr, yi], 1) return outputs + + +class ComplexRelu(nn.Module): + def __init__(self): + super().__init__() + self.real_relu = nn.PReLU() + self.imag_relu = nn.PReLU() + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + real = self.real_relu(real) + imag = self.imag_relu(imag) + return torch.cat([real, imag], dim=1) From a3b20d5ddb6673a302c455e44f14ac6ead3239b7 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:40:19 +0530 Subject: [PATCH 16/32] fix imports --- enhancer/models/dccrn.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py index 38c8145..00d9f9d 100644 --- a/enhancer/models/dccrn.py +++ b/enhancer/models/dccrn.py @@ -5,9 +5,13 @@ from torch import nn from enhancer.data import EnhancerDataset from enhancer.models import Model -from enhancer.models.complexnn import ComplexConv2d, ComplexLSTM -from enhancer.models.complexnn.conv import ComplexConvTranspose2d -from enhancer.models.complexnn.utils import ComplexBatchNorm2D, ComplexRelu +from enhancer.models.complexnn import ( + ComplexBatchNorm2D, + ComplexConv2d, + ComplexConvTranspose2d, + ComplexLSTM, + ComplexRelu, +) from enhancer.utils.transforms import ConviSTFT, ConvSTFT from enhancer.utils.utils import merge_dict From e2e413f8f3b545abfd132d9f4dfddd22671566a8 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:55:23 +0530 Subject: [PATCH 17/32] rmv --- enhancer/models/complexnn/__init__.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 enhancer/models/complexnn/__init__.py diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py deleted file mode 100644 index 68f376e..0000000 --- a/enhancer/models/complexnn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# from enhancer.models.complexnn.conv import ComplexConv2d From 438882092141cd4c2200329eae4e357bac4692b4 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:58:16 +0530 Subject: [PATCH 18/32] add imports --- enhancer/models/complexnn/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 enhancer/models/complexnn/__init__.py diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py new file mode 100644 index 0000000..8f9cef6 --- /dev/null +++ b/enhancer/models/complexnn/__init__.py @@ -0,0 +1 @@ +from enhancer.models.complexnn.conv import ComplexConv2d # noqa From 2e4a3cd25403ceff5def92d79267ddcce97cbf52 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:58:50 +0530 Subject: [PATCH 19/32] add imports --- enhancer/models/complexnn/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py index 8f9cef6..aa47ad2 100644 --- a/enhancer/models/complexnn/__init__.py +++ b/enhancer/models/complexnn/__init__.py @@ -1 +1,4 @@ from enhancer.models.complexnn.conv import ComplexConv2d # noqa +from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa +from enhancer.models.complexnn.rnn import ComplexLSTM # noqa +from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa From 70d17f6586eb2d03b69c71f76962f0265519b0ef Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 5 Nov 2022 16:59:04 +0530 Subject: [PATCH 20/32] add imports --- enhancer/models/complexnn/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py index aa47ad2..918a261 100644 --- a/enhancer/models/complexnn/__init__.py +++ b/enhancer/models/complexnn/__init__.py @@ -2,3 +2,4 @@ from enhancer.models.complexnn.conv import ComplexConv2d # noqa from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa from enhancer.models.complexnn.rnn import ComplexLSTM # noqa from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa +from enhancer.models.complexnn.utils import ComplexRelu # noqa From c21f05e3073053e6a2d0cea063b441ffdb799c8d Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:23:46 +0530 Subject: [PATCH 21/32] fix padding & init --- enhancer/models/complexnn/conv.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py index 55acc07..d9a4d0f 100644 --- a/enhancer/models/complexnn/conv.py +++ b/enhancer/models/complexnn/conv.py @@ -59,9 +59,10 @@ class ComplexConv2d(nn.Module): """ complex axis should be always 1 dim """ - input = F.pad(input, [self.padding[1], self.padding[1], 0, 0]) + input = F.pad(input, [self.padding[1], 0, 0, 0]) real, imag = torch.chunk(input, 2, 1) + real_real = self.real_conv(real) real_imag = self.imag_conv(real) @@ -72,7 +73,6 @@ class ComplexConv2d(nn.Module): imag = real_imag - imag_real out = torch.cat([real, imag], 1) - return out @@ -116,13 +116,12 @@ class ComplexConvTranspose2d(nn.Module): groups=self.groups, ) - init_weights(self.real_conv) - init_weights(self.imag_conv) + self.real_conv = init_weights(self.real_conv) + self.imag_conv = init_weights(self.imag_conv) def forward(self, input): real, imag = torch.chunk(input, 2, 1) - real_real = self.real_conv(real) real_imag = self.imag_conv(real) From 60fc4607d03c907d8832dc35fcbed363426da5d9 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:24:18 +0530 Subject: [PATCH 22/32] init projection_size as None --- enhancer/models/complexnn/rnn.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enhancer/models/complexnn/rnn.py b/enhancer/models/complexnn/rnn.py index 7d19425..847030b 100644 --- a/enhancer/models/complexnn/rnn.py +++ b/enhancer/models/complexnn/rnn.py @@ -42,6 +42,8 @@ class ComplexLSTM(nn.Module): self.imag_linear = nn.Linear( self.hidden_size * bidirectional, self.projection_size ) + else: + self.projection_size = None def forward(self, input): From d7f384791755456ec97963103099f819b6414d1f Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:24:47 +0530 Subject: [PATCH 23/32] add complex-cat --- enhancer/models/complexnn/utils.py | 37 +++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/enhancer/models/complexnn/utils.py b/enhancer/models/complexnn/utils.py index d5de558..1b6ff78 100644 --- a/enhancer/models/complexnn/utils.py +++ b/enhancer/models/complexnn/utils.py @@ -7,7 +7,7 @@ class ComplexBatchNorm2D(nn.Module): self, num_features: int, eps: float = 1e-5, - momentum: bool = True, + momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, ): @@ -25,12 +25,11 @@ class ComplexBatchNorm2D(nn.Module): self.eps = eps if self.affine: - values = torch.Tensor(self.num_features) - self.Wrr = nn.parameter.Parameter(values) - self.Wri = nn.parameter.Parameter(values) - self.Wii = nn.parameter.Parameter(values) - self.Br = nn.parameter.Parameter(values) - self.Bi = nn.parameter.Parameter(values) + self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features)) else: self.register_parameter("Wrr", None) self.register_parameter("Wri", None) @@ -39,7 +38,7 @@ class ComplexBatchNorm2D(nn.Module): self.register_parameter("Bi", None) if self.track_running_stats: - values = torch.Tensor(self.num_features) + values = torch.zeros(self.num_features) self.register_buffer("Mean_real", values) self.register_buffer("Mean_imag", values) self.register_buffer("Var_rr", values) @@ -111,8 +110,8 @@ class ComplexBatchNorm2D(nn.Module): batch_mean_real = self.Mean_real.view(vdim) batch_mean_imag = self.Mean_imag.view(vdim) - real -= batch_mean_real - imag -= batch_mean_imag + real = real - batch_mean_real + imag = imag - batch_mean_imag if training: batch_var_rr = real * real @@ -141,7 +140,7 @@ class ComplexBatchNorm2D(nn.Module): s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri t = (tau + 2 * s).sqrt() - rst = 1 / (s * t) + rst = (s * t).reciprocal() Urr = (batch_var_ii + s) * rst Uri = -batch_var_ri * rst Uii = (batch_var_rr + s) * rst @@ -162,6 +161,10 @@ class ComplexBatchNorm2D(nn.Module): yr = (Zrr * real) + (Zri * imag) yi = (Zir * real) + (Zii * imag) + if self.affine: + yr = yr + self.Br.view(vdim) + yi = yi + self.Bi.view(vdim) + outputs = torch.cat([yr, yi], 1) return outputs @@ -178,3 +181,15 @@ class ComplexRelu(nn.Module): real = self.real_relu(real) imag = self.imag_relu(imag) return torch.cat([real, imag], dim=1) + + +def complex_cat(inputs, axis=1): + + real, imag = [], [] + for data in inputs: + real_data, imag_data = torch.chunk(data, 2, axis) + real.append(real_data) + imag.append(imag_data) + real = torch.cat(real, axis) + imag = torch.cat(imag, axis) + return torch.cat([real, imag], axis) From c1d5e56ec0c4b0290cddf185c371f8e0ede886f0 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:25:27 +0530 Subject: [PATCH 24/32] transforms test --- enhancer/utils/transforms.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py index acfc323..fbdb8f9 100644 --- a/enhancer/utils/transforms.py +++ b/enhancer/utils/transforms.py @@ -18,15 +18,16 @@ class ConvFFT(nn.Module): self.window_len = window_len self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) self.window = torch.from_numpy( - get_window(window, window_len).astype("float32") + get_window(window, window_len, fftbins=True).astype("float32") ) - @property - def init_kernel(self): + def init_kernel(self, inverse=False): fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] real, imag = np.real(fourier_basis), np.imag(fourier_basis) kernel = np.concatenate([real, imag], 1).T + if inverse: + kernel = np.linalg.pinv(kernel).T kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1) kernel *= self.window return kernel @@ -42,7 +43,7 @@ class ConvSTFT(ConvFFT): ): super().__init__(window_len=window_len, nfft=nfft, window=window) self.hop_size = hop_size if hop_size else window_len // 2 - self.register_buffer("weight", self.init_kernel) + self.register_buffer("weight", self.init_kernel()) def forward(self, input): @@ -73,7 +74,7 @@ class ConviSTFT(ConvFFT): ): super().__init__(window_len=window_len, nfft=nfft, window=window) self.hop_size = hop_size if hop_size else window_len // 2 - self.register_buffer("weight", self.init_kernel) + self.register_buffer("weight", self.init_kernel(True)) self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1)) def forward(self, input, phase=None): @@ -85,7 +86,7 @@ class ConviSTFT(ConvFFT): out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) - out = out / coeff + out = out / (coeff + 1e-8) pad = self.window_len - self.hop_size out = out[..., pad:-pad] return out From fc33bd83b68757cff1a3c73501cb3c7bb34740e8 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:25:54 +0530 Subject: [PATCH 25/32] transforms test --- tests/transforms_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/transforms_test.py b/tests/transforms_test.py index 3053b09..89425ad 100644 --- a/tests/transforms_test.py +++ b/tests/transforms_test.py @@ -12,3 +12,7 @@ def test_stft_istft(): spectrogram = stft(sample_input) waveform = istft(spectrogram) assert sample_input.shape == waveform.shape + assert ( + torch.isclose(waveform, sample_input).sum().item() + > sample_input.shape[-1] // 2 + ) From 511d2141d4ee2229fcb52dad59966e4d581d918a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:26:51 +0530 Subject: [PATCH 26/32] DCCRN implementation --- enhancer/models/dccrn.py | 124 +++++++++++++++++++++++++++++++++------ 1 file changed, 105 insertions(+), 19 deletions(-) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py index 00d9f9d..c6ee837 100644 --- a/enhancer/models/dccrn.py +++ b/enhancer/models/dccrn.py @@ -1,6 +1,8 @@ import logging from typing import Any, List, Optional, Tuple, Union +import torch +import torch.nn.functional as F from torch import nn from enhancer.data import EnhancerDataset @@ -12,6 +14,7 @@ from enhancer.models.complexnn import ( ComplexLSTM, ComplexRelu, ) +from enhancer.models.complexnn.utils import complex_cat from enhancer.utils.transforms import ConviSTFT, ConvSTFT from enhancer.utils.utils import merge_dict @@ -54,6 +57,7 @@ class DCCRN_DECODER(nn.Module): in_channels: int, out_channels: int, kernel_size: Tuple[int, int], + layer: int = 0, complex_norm: bool = True, complex_relu: bool = True, stride: Tuple[int, int] = (2, 1), @@ -64,18 +68,30 @@ class DCCRN_DECODER(nn.Module): batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d activation = ComplexRelu() if complex_relu else nn.PReLU() - self.decoder = nn.Sequential( - ComplexConvTranspose2d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - output_padding=output_padding, - ), - batchnorm(out_channels), - activation, - ) + if layer != 0: + self.decoder = nn.Sequential( + ComplexConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ), + batchnorm(out_channels), + activation, + ) + else: + self.decoder = nn.Sequential( + ComplexConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + ) def forward(self, waveform): @@ -187,6 +203,7 @@ class DCCRN(Model): decoder_ = DCCRN_DECODER( hidden_size + hidden_size, num_channels, + layer=layer, kernel_size=(encoder_decoder["kernel_size"], 2), stride=(encoder_decoder["stride"], 1), padding=(encoder_decoder["padding"], 0), @@ -231,14 +248,83 @@ class DCCRN(Model): ) self.lstm = nn.Sequential(*lstms) else: - self.lstm = nn.LSTM( - input_size=hidden_size * kernel_size, - hidden_sizs=lstm["hidden_size"], - num_layers=lstm["num_layers"], - dropout=0.0, - batch_first=False, + self.lstm = nn.Sequential( + nn.LSTM( + input_size=hidden_size * kernel_size, + hidden_sizs=lstm["hidden_size"], + num_layers=lstm["num_layers"], + dropout=0.0, + batch_first=False, + )[0], + nn.Linear(lstm["hidden"], hidden_size * kernel_size), ) def forward(self, waveform): - return waveform + waveform_stft = self.stft(waveform) + real = waveform_stft[:, : self.stft.nfft // 2 + 1] + imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] + + mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9) + phase_spec = torch.atan2(imag, real) + complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:] + + encoder_outputs = [] + out = complex_spec + for _, encoder in enumerate(self.encoder): + out = encoder(out) + encoder_outputs.append(out) + + B, C, D, T = out.size() + out = out.permute(3, 0, 1, 2) + if self.complex_lstm: + + lstm_real = out[:, :, : C // 2] + lstm_imag = out[:, :, C // 2 :] + lstm_real = lstm_real.reshape(T, B, C // 2 * D) + lstm_imag = lstm_imag.reshape(T, B, C // 2 * D) + lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag]) + lstm_real = lstm_real.reshape(T, B, C // 2, D) + lstm_imag = lstm_imag.reshape(T, B, C // 2, D) + out = torch.cat([lstm_real, lstm_imag], 2) + else: + out = out.reshape(T, B, C * D) + out = self.lstm(out) + out = out.reshape(T, B, D, C) + + out = out.permute(1, 2, 3, 0) + for layer, decoder in enumerate(self.decoder): + skip_connection = encoder_outputs.pop(-1) + out = complex_cat([skip_connection, out]) + out = decoder(out) + out = out[..., 1:] + mask_real, mask_imag = out[:, 0], out[:, 1] + mask_real = F.pad(mask_real, [0, 0, 1, 0]) + mask_imag = F.pad(mask_imag, [0, 0, 1, 0]) + if self.masking_mode == "E": + + mask_mag = torch.sqrt(mask_real**2 + mask_imag**2) + real_phase = mask_real / (mask_mag + 1e-8) + imag_phase = mask_imag / (mask_mag + 1e-8) + mask_phase = torch.atan2(imag_phase, real_phase) + mask_mag = torch.tanh(mask_mag) + est_mag = mask_mag * mag_spec + est_phase = mask_phase * phase_spec + # cos(theta) + isin(theta) + real = est_mag + torch.cos(est_phase) + imag = est_mag + torch.sin(est_phase) + + if self.masking_mode == "C": + + real = real * mask_real - imag * mask_imag + imag = real * mask_imag + imag * mask_real + + else: + + real = real * mask_real + imag = imag * mask_imag + + spec = torch.cat([real, imag], 1) + wav = self.istft(spec).squeeze(1) + wav = wav.clamp_(-1, 1) + return wav From 15c1d1ad947b655ca0d92058a1dc21e23ebbecad Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:52:11 +0530 Subject: [PATCH 27/32] fix batchnorm eval() mode --- enhancer/models/complexnn/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/enhancer/models/complexnn/utils.py b/enhancer/models/complexnn/utils.py index 1b6ff78..0c28f9b 100644 --- a/enhancer/models/complexnn/utils.py +++ b/enhancer/models/complexnn/utils.py @@ -125,6 +125,10 @@ class ComplexBatchNorm2D(nn.Module): self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor) self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor) self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor) + else: + batch_var_rr = self.Var_rr.view(vdim) + batch_var_ii = self.Var_ii.view(vdim) + batch_var_ri = self.Var_ri.view(vdim) batch_var_rr += self.eps batch_var_ii += self.eps From 40e8722014b0e12ef47638046975a00951000b70 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:52:35 +0530 Subject: [PATCH 28/32] fix o/p shape --- enhancer/models/dccrn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py index c6ee837..72d8a23 100644 --- a/enhancer/models/dccrn.py +++ b/enhancer/models/dccrn.py @@ -325,6 +325,6 @@ class DCCRN(Model): imag = imag * mask_imag spec = torch.cat([real, imag], 1) - wav = self.istft(spec).squeeze(1) + wav = self.istft(spec) wav = wav.clamp_(-1, 1) return wav From 1a4102cc53d66525f0dc732b5c0c824185968ce2 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 10:53:08 +0530 Subject: [PATCH 29/32] dccrn --- enhancer/cli/train_config/model/DCCRN.yaml | 25 ++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 enhancer/cli/train_config/model/DCCRN.yaml diff --git a/enhancer/cli/train_config/model/DCCRN.yaml b/enhancer/cli/train_config/model/DCCRN.yaml new file mode 100644 index 0000000..3190391 --- /dev/null +++ b/enhancer/cli/train_config/model/DCCRN.yaml @@ -0,0 +1,25 @@ +_target_: enhancer.models.dccrn.DCCRN +num_channels: 1 +sampling_rate : 16000 +complex_lstm : True +complex_norm : True +complex_relu : True +masking_mode : True + +encoder_decoder: + initial_output_channels : 32 + depth : 6 + kernel_size : 5 + growth_factor : 2 + stride : 2 + padding : 2 + output_padding : 1 + +lstm: + num_layers : 2 + hidden_size : 256 + +stft: + window_len : 400 + hop_size : 100 + nfft : 512 From 77699ce7f90dc17a9904dd3c6955e528f4a2100d Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 11:15:30 +0530 Subject: [PATCH 30/32] fix tests --- tests/models/complexnn_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py index 8c18ed5..524a6cf 100644 --- a/tests/models/complexnn_test.py +++ b/tests/models/complexnn_test.py @@ -1,8 +1,8 @@ import torch from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d -from enhancer.models.complexnn.norm import ComplexBatchNorm2D from enhancer.models.complexnn.rnn import ComplexLSTM +from enhancer.models.complexnn.utils import ComplexBatchNorm2D def test_complexconv2d(): @@ -12,7 +12,7 @@ def test_complexconv2d(): ) with torch.no_grad(): out = conv(sample_input) - assert out.shape == torch.Size([1, 32, 128, 14]) + assert out.shape == torch.Size([1, 32, 128, 13]) def test_complexconvtranspose2d(): From 6573bc4c5e7b1798c8747c499f65210be9e7993d Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 11:33:00 +0530 Subject: [PATCH 31/32] ensure num_channels --- enhancer/models/dccrn.py | 8 ++++++++ enhancer/models/demucs.py | 6 +++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py index 72d8a23..7b1e5b1 100644 --- a/enhancer/models/dccrn.py +++ b/enhancer/models/dccrn.py @@ -261,6 +261,14 @@ class DCCRN(Model): def forward(self, waveform): + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" + ) + waveform_stft = self.stft(waveform) real = waveform_stft[:, : self.stft.nfft // 2 + 1] imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index e5fa945..fafb84e 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -204,9 +204,9 @@ class Demucs(Model): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1) != 1: - raise TypeError( - f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" ) if self.normalize: waveform = waveform.mean(dim=1, keepdim=True) From 6626ad75e71f8639549e372ccad0f856f1ceb373 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 7 Nov 2022 11:34:21 +0530 Subject: [PATCH 32/32] fix tests --- tests/models/demucs_test.py | 2 +- tests/models/test_dccrn.py | 43 +++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 tests/models/test_dccrn.py diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index f5a0ec4..29e030e 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -30,7 +30,7 @@ def test_forward(batch_size, samples): data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): _ = model(data) diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py new file mode 100644 index 0000000..96a853b --- /dev/null +++ b/tests/models/test_dccrn.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from enhancer.data.dataset import EnhancerDataset +from enhancer.models.dccrn import DCCRN +from enhancer.utils.config import Files + + +@pytest.fixture +def vctk_dataset(): + root_dir = "tests/data/vctk" + files = Files( + train_clean="clean_testset_wav", + train_noisy="noisy_testset_wav", + test_clean="clean_testset_wav", + test_noisy="noisy_testset_wav", + ) + dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files) + return dataset + + +@pytest.mark.parametrize("batch_size,samples", [(1, 1000)]) +def test_forward(batch_size, samples): + model = DCCRN() + model.eval() + + data = torch.rand(batch_size, 1, samples, requires_grad=False) + with torch.no_grad(): + _ = model(data) + + data = torch.rand(batch_size, 2, samples, requires_grad=False) + with torch.no_grad(): + with pytest.raises(ValueError): + _ = model(data) + + +@pytest.mark.parametrize( + "dataset,channels,loss", + [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])], +) +def test_demucs_init(dataset, channels, loss): + with torch.no_grad(): + _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)