DCCRN implementation
This commit is contained in:
parent
fc33bd83b6
commit
511d2141d4
|
|
@ -1,6 +1,8 @@
|
|||
import logging
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from enhancer.data import EnhancerDataset
|
||||
|
|
@ -12,6 +14,7 @@ from enhancer.models.complexnn import (
|
|||
ComplexLSTM,
|
||||
ComplexRelu,
|
||||
)
|
||||
from enhancer.models.complexnn.utils import complex_cat
|
||||
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
|
||||
from enhancer.utils.utils import merge_dict
|
||||
|
||||
|
|
@ -54,6 +57,7 @@ class DCCRN_DECODER(nn.Module):
|
|||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Tuple[int, int],
|
||||
layer: int = 0,
|
||||
complex_norm: bool = True,
|
||||
complex_relu: bool = True,
|
||||
stride: Tuple[int, int] = (2, 1),
|
||||
|
|
@ -64,18 +68,30 @@ class DCCRN_DECODER(nn.Module):
|
|||
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
|
||||
activation = ComplexRelu() if complex_relu else nn.PReLU()
|
||||
|
||||
self.decoder = nn.Sequential(
|
||||
ComplexConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
),
|
||||
batchnorm(out_channels),
|
||||
activation,
|
||||
)
|
||||
if layer != 0:
|
||||
self.decoder = nn.Sequential(
|
||||
ComplexConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
),
|
||||
batchnorm(out_channels),
|
||||
activation,
|
||||
)
|
||||
else:
|
||||
self.decoder = nn.Sequential(
|
||||
ComplexConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, waveform):
|
||||
|
||||
|
|
@ -187,6 +203,7 @@ class DCCRN(Model):
|
|||
decoder_ = DCCRN_DECODER(
|
||||
hidden_size + hidden_size,
|
||||
num_channels,
|
||||
layer=layer,
|
||||
kernel_size=(encoder_decoder["kernel_size"], 2),
|
||||
stride=(encoder_decoder["stride"], 1),
|
||||
padding=(encoder_decoder["padding"], 0),
|
||||
|
|
@ -231,14 +248,83 @@ class DCCRN(Model):
|
|||
)
|
||||
self.lstm = nn.Sequential(*lstms)
|
||||
else:
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=hidden_size * kernel_size,
|
||||
hidden_sizs=lstm["hidden_size"],
|
||||
num_layers=lstm["num_layers"],
|
||||
dropout=0.0,
|
||||
batch_first=False,
|
||||
self.lstm = nn.Sequential(
|
||||
nn.LSTM(
|
||||
input_size=hidden_size * kernel_size,
|
||||
hidden_sizs=lstm["hidden_size"],
|
||||
num_layers=lstm["num_layers"],
|
||||
dropout=0.0,
|
||||
batch_first=False,
|
||||
)[0],
|
||||
nn.Linear(lstm["hidden"], hidden_size * kernel_size),
|
||||
)
|
||||
|
||||
def forward(self, waveform):
    """Enhance ``waveform`` by masking its STFT and resynthesising audio.

    Pipeline: ``self.stft`` -> encoder stack -> ``self.lstm`` -> decoder
    stack with skip connections -> mask applied to the input spectrum
    according to ``self.masking_mode`` -> ``self.istft``.

    Args:
        waveform: input passed straight to ``self.stft``.
            (assumes a batch of audio signals; the exact shape is dictated
            by ``self.stft`` -- TODO confirm against ConvSTFT)

    Returns:
        The output of ``self.istft`` with dimension 1 squeezed and values
        clamped to ``[-1, 1]``.

    Masking modes:
        ``"E"``: polar masking -- the mask magnitude (squashed with tanh)
            scales the input magnitude and the mask phase is added to the
            input phase.
        ``"C"``: complex multiplication of the spectrum by the mask.
        anything else: real and imaginary parts are scaled independently.
    """
    waveform_stft = self.stft(waveform)
    # The STFT output holds the real bins followed by the imaginary bins
    # along dim 1; each half has nfft // 2 + 1 entries.
    num_bins = self.stft.nfft // 2 + 1
    real = waveform_stft[:, :num_bins]
    imag = waveform_stft[:, num_bins:]

    # Small epsilon keeps sqrt differentiable for silent bins.
    mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
    phase_spec = torch.atan2(imag, real)
    # Drop the first frequency bin; it is zero-padded back onto the mask
    # after decoding.
    # NOTE(review): the encoder is fed (magnitude, phase) rather than
    # (real, imag) although the layers are complex-valued -- confirm that
    # this is intentional.
    complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]

    encoder_outputs = []
    out = complex_spec
    for encoder in self.encoder:
        out = encoder(out)
        encoder_outputs.append(out)

    B, C, D, T = out.size()
    # Move the last axis to the front so the recurrent layer iterates
    # over it: (B, C, D, T) -> (T, B, C, D).
    out = out.permute(3, 0, 1, 2)
    if self.complex_lstm:
        # First half of the channels is the real part, second half the
        # imaginary part; each is flattened to (T, B, C/2 * D).
        lstm_real = out[:, :, : C // 2].reshape(T, B, C // 2 * D)
        lstm_imag = out[:, :, C // 2 :].reshape(T, B, C // 2 * D)
        lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
        lstm_real = lstm_real.reshape(T, B, C // 2, D)
        lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
        out = torch.cat([lstm_real, lstm_imag], 2)
    else:
        out = out.reshape(T, B, C * D)
        out = self.lstm(out)
        # Undo the flattening with the same (C, D) order it was built
        # from; swapping them would scramble channels and frequencies.
        out = out.reshape(T, B, C, D)

    # Back to (B, C, D, T) for the decoder.
    out = out.permute(1, 2, 3, 0)
    for decoder in self.decoder:
        # Skip connections are consumed deepest-first.
        skip_connection = encoder_outputs.pop()
        out = complex_cat([skip_connection, out])
        out = decoder(out)
        # Drop the first frame along the last axis after every decoder.
        out = out[..., 1:]

    mask_real, mask_imag = out[:, 0], out[:, 1]
    # Re-insert the frequency bin removed before encoding (zero mask).
    mask_real = F.pad(mask_real, [0, 0, 1, 0])
    mask_imag = F.pad(mask_imag, [0, 0, 1, 0])

    if self.masking_mode == "E":
        mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
        real_phase = mask_real / (mask_mag + 1e-8)
        imag_phase = mask_imag / (mask_mag + 1e-8)
        mask_phase = torch.atan2(imag_phase, real_phase)
        mask_mag = torch.tanh(mask_mag)
        est_mag = mask_mag * mag_spec
        # Multiplying complex numbers adds their phases.
        est_phase = phase_spec + mask_phase
        # Polar -> rectangular: mag * (cos(theta) + i * sin(theta)).
        real = est_mag * torch.cos(est_phase)
        imag = est_mag * torch.sin(est_phase)
    elif self.masking_mode == "C":
        # Complex product; tuple assignment so that ``imag`` is computed
        # from the unmodified ``real``.
        real, imag = (
            real * mask_real - imag * mask_imag,
            real * mask_imag + imag * mask_real,
        )
    else:
        real = real * mask_real
        imag = imag * mask_imag

    spec = torch.cat([real, imag], 1)
    wav = self.istft(spec).squeeze(1)
    return wav.clamp(-1, 1)
|
||||
|
|
|
|||
Loading…
Reference in New Issue