Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

shahules786 committed 2022-11-07 11:54:25 +05:30
commit ce04720e59
14 changed files with 1011 additions and 25 deletions

View File: DCCRN model config (YAML)

@@ -0,0 +1,25 @@
_target_: enhancer.models.dccrn.DCCRN
num_channels: 1
sampling_rate: 16000
complex_lstm: True
complex_norm: True
complex_relu: True
masking_mode: "E"  # DCCRN.__init__ expects a string mode ("E", "C", ...), not a bool
encoder_decoder:
  initial_output_channels: 32
  depth: 6
  kernel_size: 5
  growth_factor: 2
  stride: 2
  padding: 2
  output_padding: 1
lstm:
  num_layers: 2
  hidden_size: 256
stft:
  window_len: 400
  hop_size: 100
  nfft: 512
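This is a Hydra-style config: the _target_ key names the class to instantiate, and the remaining keys become constructor arguments. A minimal usage sketch, assuming hydra-core and omegaconf are installed (the config path is illustrative):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# load the YAML above and build enhancer.models.dccrn.DCCRN from it
cfg = OmegaConf.load("model/DCCRN.yaml")  # illustrative path
model = instantiate(cfg)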

View File: enhancer/data/dataset.py

@@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule):
         name: str,
         root_dir: str,
         files: Files,
-        valid_minutes: float = 0.20,
+        min_valid_minutes: float = 0.20,
         duration: float = 1.0,
         stride=None,
         sampling_rate: int = 48000,
@@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
-        if valid_minutes > 0.0:
-            self.valid_minutes = valid_minutes
+        if min_valid_minutes > 0.0:
+            self.min_valid_minutes = min_valid_minutes
         else:
-            raise ValueError("valid_minutes must be greater than 0")
+            raise ValueError("min_valid_minutes must be greater than 0")
         self.augmentations = augmentations
@@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule):
         )
         train_data = fp.prepare_matching_dict()
         train_data, self.val_data = self.train_valid_split(
-            train_data, valid_minutes=self.valid_minutes, random_state=42
+            train_data,
+            min_valid_minutes=self.min_valid_minutes,
+            random_state=42,
         )
         self.train_data = self.prepare_traindata(train_data)
@@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule):
         self._test = self.prepare_mapstype(test_data)

     def train_valid_split(
-        self, data, valid_minutes: float = 20, random_state: int = 42
+        self, data, min_valid_minutes: float = 20, random_state: int = 42
     ):
-        valid_minutes *= 60
+        min_valid_minutes *= 60
         valid_sec_now = 0.0
         valid_indices = []
         all_speakers = np.unique(

@@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule):
         possible_indices = list(range(0, len(all_speakers)))
         rng = create_unique_rng(len(all_speakers))
-        while valid_sec_now <= valid_minutes:
+        while valid_sec_now <= min_valid_minutes:
             speaker_index = rng.choice(possible_indices)
             possible_indices.remove(speaker_index)
             speaker_name = all_speakers[speaker_index]
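For readers skimming the diff, a minimal sketch of the speaker-level split that train_valid_split implements, assuming a mapping of speaker name to total seconds of audio (illustrative code, not the repo's):

import random

def split_by_speaker(durations: dict, min_valid_minutes: float, seed: int = 42):
    """durations: speaker name -> total seconds of audio for that speaker."""
    rng = random.Random(seed)
    remaining = list(durations)
    valid_speakers, valid_sec = [], 0.0
    # draw random speakers until the validation pool covers the requested minutes
    while valid_sec < min_valid_minutes * 60 and remaining:
        speaker = rng.choice(remaining)
        remaining.remove(speaker)
        valid_speakers.append(speaker)
        valid_sec += durations[speaker]
    return valid_speakers, remaining  # validation speakers, train speakers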
@@ -257,6 +259,9 @@ class EnhancerDataset(TaskDataset):
     files : Files
         dataclass containing train_clean, train_noisy, test_clean, test_noisy
         folder names (refer enhancer.utils.Files dataclass)
+    min_valid_minutes : float
+        minimum duration of the validation split, in minutes. Speakers are
+        drawn at random from the train data until this duration is reached.
     duration : float
         expected audio duration of single audio sample for training
     sampling_rate : int
@@ -271,6 +276,7 @@ class EnhancerDataset(TaskDataset):
         use one_to_many mapping for multiple noisy files for each clean file
     """
+
     def __init__(

@@ -278,7 +284,7 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
-        valid_minutes=5.0,
+        min_valid_minutes=5.0,
         duration=1.0,
         stride=None,
         sampling_rate=48000,
@@ -292,7 +298,7 @@ class EnhancerDataset(TaskDataset):
             name=name,
             root_dir=root_dir,
             files=files,
-            valid_minutes=valid_minutes,
+            min_valid_minutes=min_valid_minutes,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,
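The renamed argument in use, sketched with the Files layout from the tests at the bottom of this diff (paths are illustrative):

from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.config import Files

files = Files(
    train_clean="clean_trainset_wav",
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk", root_dir="data/vctk", files=files, min_valid_minutes=5.0
)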

View File: enhancer/models/complexnn/__init__.py

@@ -0,0 +1,5 @@
from enhancer.models.complexnn.conv import ComplexConv2d # noqa
from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa
from enhancer.models.complexnn.rnn import ComplexLSTM # noqa
from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa
from enhancer.models.complexnn.utils import ComplexRelu # noqa

View File: enhancer/models/complexnn/conv.py

@@ -0,0 +1,136 @@
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn
def init_weights(nnet):
nn.init.xavier_normal_(nnet.weight.data)
nn.init.constant_(nnet.bias, 0.0)
return nnet
class ComplexConv2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
groups: int = 1,
dilation: int = 1,
):
"""
Complex Conv2d (non-causal)
"""
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.dilation = dilation
self.real_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = init_weights(self.imag_conv)
self.real_conv = init_weights(self.real_conv)
    def forward(self, input):
        """
        The real/imag packing is always along dim 1.
        """
        # pad only to the left along the time axis
        input = F.pad(input, [self.padding[1], 0, 0, 0])
        real, imag = torch.chunk(input, 2, 1)
        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)
        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)
        # (A + iB)(C + iD) = (AC - BD) + i(AD + BC)
        real = real_real - imag_imag
        imag = real_imag + imag_real
        out = torch.cat([real, imag], 1)
        return out
class ComplexConvTranspose2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
output_padding: Tuple[int, int] = (0, 0),
groups: int = 1,
):
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.output_padding = output_padding
self.real_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.imag_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.real_conv = init_weights(self.real_conv)
self.imag_conv = init_weights(self.imag_conv)
    def forward(self, input):
        real, imag = torch.chunk(input, 2, 1)
        real_real = self.real_conv(real)
        real_imag = self.imag_conv(real)
        imag_imag = self.imag_conv(imag)
        imag_real = self.real_conv(imag)
        # (A + iB)(C + iD) = (AC - BD) + i(AD + BC)
        real = real_real - imag_imag
        imag = real_imag + imag_real
        out = torch.cat([real, imag], 1)
        return out
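A quick shape check for the complex convolution, mirroring the unit test further down in this diff: channel dim 1 packs [real | imag], so 2 input channels mean one complex channel.

import torch
from enhancer.models.complexnn.conv import ComplexConv2d

x = torch.rand(1, 2, 256, 13)  # (batch, real|imag, freq, time)
conv = ComplexConv2d(2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1))
with torch.no_grad():
    y = conv(x)
print(y.shape)  # torch.Size([1, 32, 128, 13])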

View File: enhancer/models/complexnn/rnn.py

@@ -0,0 +1,68 @@
from typing import Optional
import torch
from torch import nn
class ComplexLSTM(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
projection_size: Optional[int] = None,
bidirectional: bool = False,
):
super().__init__()
self.input_size = input_size // 2
self.hidden_size = hidden_size // 2
self.num_layers = num_layers
self.real_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
self.imag_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
bidirectional = 2 if bidirectional else 1
if projection_size is not None:
self.projection_size = projection_size // 2
self.real_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
self.imag_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
else:
self.projection_size = None
def forward(self, input):
        if isinstance(input, list):
real, imag = input
else:
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_lstm(real)[0]
real_imag = self.imag_lstm(real)[0]
imag_imag = self.imag_lstm(imag)[0]
imag_real = self.real_lstm(imag)[0]
real = real_real - imag_imag
imag = imag_real + real_imag
if self.projection_size is not None:
real = self.real_linear(real)
imag = self.imag_linear(imag)
return [real, imag]
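Usage sketch, mirroring the unit test below: ComplexLSTM accepts either a [real, imag] pair or a single tensor whose dim 1 packs the real and imaginary halves, and it returns a [real, imag] pair.

import torch
from enhancer.models.complexnn.rnn import ComplexLSTM

x = torch.rand(13, 2, 128)  # (time, real|imag, features)
lstm = ComplexLSTM(input_size=256, hidden_size=256, projection_size=1024)
with torch.no_grad():
    real, imag = lstm(x)
print(real.shape, imag.shape)  # torch.Size([13, 1, 512]) each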

View File: enhancer/models/complexnn/utils.py

@@ -0,0 +1,199 @@
import torch
from torch import nn
class ComplexBatchNorm2D(nn.Module):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
"""
Complex batch normalization 2D
https://arxiv.org/abs/1705.09792
"""
super().__init__()
self.num_features = num_features // 2
self.affine = affine
self.momentum = momentum
self.track_running_stats = track_running_stats
self.eps = eps
if self.affine:
self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
else:
self.register_parameter("Wrr", None)
self.register_parameter("Wri", None)
self.register_parameter("Wii", None)
self.register_parameter("Br", None)
self.register_parameter("Bi", None)
        if self.track_running_stats:
            # each running statistic needs its own tensor; registering one
            # shared tensor for all buffers would alias them
            self.register_buffer("Mean_real", torch.zeros(self.num_features))
            self.register_buffer("Mean_imag", torch.zeros(self.num_features))
            self.register_buffer("Var_rr", torch.zeros(self.num_features))
            self.register_buffer("Var_ri", torch.zeros(self.num_features))
            self.register_buffer("Var_ii", torch.zeros(self.num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
else:
self.register_parameter("Mean_real", None)
self.register_parameter("Mean_imag", None)
self.register_parameter("Var_rr", None)
self.register_parameter("Var_ri", None)
self.register_parameter("Var_ii", None)
self.register_parameter("num_batches_tracked", None)
self.reset_parameters()
def reset_parameters(self):
if self.affine:
self.Wrr.data.fill_(1)
self.Wii.data.fill_(1)
self.Wri.data.uniform_(-0.9, 0.9)
self.Br.data.fill_(0)
self.Bi.data.fill_(0)
self.reset_running_stats()
def reset_running_stats(self):
if self.track_running_stats:
self.Mean_real.zero_()
self.Mean_imag.zero_()
self.Var_rr.fill_(1)
self.Var_ri.zero_()
self.Var_ii.fill_(1)
self.num_batches_tracked.zero_()
def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
**self.__dict__
)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
exp_avg_factor = 0.0
training = self.training and self.track_running_stats
if training:
self.num_batches_tracked += 1
if self.momentum is None:
exp_avg_factor = 1 / self.num_batches_tracked
else:
exp_avg_factor = self.momentum
redux = [i for i in reversed(range(real.dim())) if i != 1]
vdim = [1] * real.dim()
vdim[1] = real.size(1)
if training:
batch_mean_real, batch_mean_imag = real, imag
for dim in redux:
batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
if self.track_running_stats:
self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
else:
batch_mean_real = self.Mean_real.view(vdim)
batch_mean_imag = self.Mean_imag.view(vdim)
real = real - batch_mean_real
imag = imag - batch_mean_imag
if training:
batch_var_rr = real * real
batch_var_ri = real * imag
batch_var_ii = imag * imag
for dim in redux:
batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
if self.track_running_stats:
self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
else:
batch_var_rr = self.Var_rr.view(vdim)
batch_var_ii = self.Var_ii.view(vdim)
batch_var_ri = self.Var_ri.view(vdim)
batch_var_rr += self.eps
batch_var_ii += self.eps
        # Covariance matrix
        # | batch_var_rr batch_var_ri |
        # | batch_var_ir batch_var_ii |   (batch_var_ir == batch_var_ri)
        # Inverse square root of the covariance matrix, combining
        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        # https://mathworld.wolfram.com/MatrixInverse.html
        tau = batch_var_rr + batch_var_ii
        det = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
        s = det.sqrt()
        t = (tau + 2 * s).sqrt()
        rst = (s * t).reciprocal()
Urr = (batch_var_ii + s) * rst
Uri = -batch_var_ri * rst
Uii = (batch_var_rr + s) * rst
if self.affine:
Wrr, Wri, Wii = (
self.Wrr.view(vdim),
self.Wri.view(vdim),
self.Wii.view(vdim),
)
Zrr = (Wrr * Urr) + (Wri * Uri)
Zri = (Wrr * Uri) + (Wri * Uii)
Zir = (Wii * Uri) + (Wri * Urr)
Zii = (Wri * Uri) + (Wii * Uii)
else:
Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
yr = (Zrr * real) + (Zri * imag)
yi = (Zir * real) + (Zii * imag)
if self.affine:
yr = yr + self.Br.view(vdim)
yi = yi + self.Bi.view(vdim)
outputs = torch.cat([yr, yi], 1)
return outputs
class ComplexRelu(nn.Module):
def __init__(self):
super().__init__()
self.real_relu = nn.PReLU()
self.imag_relu = nn.PReLU()
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real = self.real_relu(real)
imag = self.imag_relu(imag)
return torch.cat([real, imag], dim=1)
def complex_cat(inputs, axis=1):
real, imag = [], []
for data in inputs:
real_data, imag_data = torch.chunk(data, 2, axis)
real.append(real_data)
imag.append(imag_data)
real = torch.cat(real, axis)
imag = torch.cat(imag, axis)
return torch.cat([real, imag], axis)
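complex_cat is what the DCCRN decoder (next file) uses for its skip connections: it concatenates complex tensors channel-wise while keeping the [real | imag] packing intact. A small sketch of the resulting layout:

import torch
from enhancer.models.complexnn.utils import complex_cat

a = torch.rand(1, 4, 8, 8)  # 2 real + 2 imag channels
b = torch.rand(1, 4, 8, 8)
out = complex_cat([a, b], axis=1)
print(out.shape)  # torch.Size([1, 8, 8, 8]): [real_a, real_b, imag_a, imag_b]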

View File: enhancer/models/dccrn.py (new file, 338 lines)

@@ -0,0 +1,338 @@
import logging
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import (
ComplexBatchNorm2D,
ComplexConv2d,
ComplexConvTranspose2d,
ComplexLSTM,
ComplexRelu,
)
from enhancer.models.complexnn.utils import complex_cat
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict
class DCCRN_ENCODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channel: int,
kernel_size: Tuple[int, int],
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 1),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
self.encoder = nn.Sequential(
ComplexConv2d(
in_channels,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
batchnorm(out_channel),
activation,
)
def forward(self, waveform):
return self.encoder(waveform)
class DCCRN_DECODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int],
layer: int = 0,
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 0),
output_padding: Tuple[int, int] = (1, 0),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
if layer != 0:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
),
batchnorm(out_channels),
activation,
)
else:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
)
)
def forward(self, waveform):
return self.decoder(waveform)
class DCCRN(Model):
STFT_DEFAULTS = {
"window_len": 400,
"hop_size": 100,
"nfft": 512,
"window": "hamming",
}
ED_DEFAULTS = {
"initial_output_channels": 32,
"depth": 6,
"kernel_size": 5,
"growth_factor": 2,
"stride": 2,
"padding": 2,
"output_padding": 1,
}
LSTM_DEFAULTS = {
"num_layers": 2,
"hidden_size": 256,
}
def __init__(
self,
stft: Optional[dict] = None,
encoder_decoder: Optional[dict] = None,
lstm: Optional[dict] = None,
complex_lstm: bool = True,
complex_norm: bool = True,
complex_relu: bool = True,
masking_mode: str = "E",
num_channels: int = 1,
sampling_rate=16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List, Any] = "mse",
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
stft = merge_dict(self.STFT_DEFAULTS, stft)
self.save_hyperparameters(
"encoder_decoder",
"lstm",
"stft",
"complex_lstm",
"complex_norm",
"masking_mode",
)
self.complex_lstm = complex_lstm
self.complex_norm = complex_norm
self.masking_mode = masking_mode
self.stft = ConvSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.istft = ConviSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
num_channels *= 2
hidden_size = encoder_decoder["initial_output_channels"]
growth_factor = 2
for layer in range(encoder_decoder["depth"]):
encoder_ = DCCRN_ENCODER(
num_channels,
hidden_size,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 1),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.encoder.append(encoder_)
decoder_ = DCCRN_DECODER(
hidden_size + hidden_size,
num_channels,
layer=layer,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 0),
output_padding=(encoder_decoder["output_padding"], 0),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.decoder.insert(0, decoder_)
if layer < encoder_decoder["depth"] - 3:
num_channels = hidden_size
hidden_size *= growth_factor
else:
num_channels = hidden_size
kernel_size = hidden_size / 2
hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])
if self.complex_lstm:
lstms = []
for layer in range(lstm["num_layers"]):
if layer == 0:
input_size = int(hidden_size * kernel_size)
else:
input_size = lstm["hidden_size"]
if layer == lstm["num_layers"] - 1:
projection_size = int(hidden_size * kernel_size)
else:
projection_size = None
kwargs = {
"input_size": input_size,
"hidden_size": lstm["hidden_size"],
"num_layers": 1,
}
lstms.append(
ComplexLSTM(projection_size=projection_size, **kwargs)
)
self.lstm = nn.Sequential(*lstms)
        else:
            # plain (non-complex) LSTM path; its output tuple is unpacked
            # and projected in forward()
            self.lstm = nn.LSTM(
                input_size=int(hidden_size * kernel_size),
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )
            self.lstm_projection = nn.Linear(
                lstm["hidden_size"], int(hidden_size * kernel_size)
            )
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != self.hparams.num_channels:
raise ValueError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
)
waveform_stft = self.stft(waveform)
real = waveform_stft[:, : self.stft.nfft // 2 + 1]
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
phase_spec = torch.atan2(imag, real)
complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
encoder_outputs = []
out = complex_spec
for _, encoder in enumerate(self.encoder):
out = encoder(out)
encoder_outputs.append(out)
B, C, D, T = out.size()
out = out.permute(3, 0, 1, 2)
if self.complex_lstm:
lstm_real = out[:, :, : C // 2]
lstm_imag = out[:, :, C // 2 :]
lstm_real = lstm_real.reshape(T, B, C // 2 * D)
lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
lstm_real = lstm_real.reshape(T, B, C // 2, D)
lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
out = torch.cat([lstm_real, lstm_imag], 2)
        else:
            out = out.reshape(T, B, C * D)
            out, _ = self.lstm(out)
            out = self.lstm_projection(out)
            out = out.reshape(T, B, C, D)
out = out.permute(1, 2, 3, 0)
for layer, decoder in enumerate(self.decoder):
skip_connection = encoder_outputs.pop(-1)
out = complex_cat([skip_connection, out])
out = decoder(out)
out = out[..., 1:]
mask_real, mask_imag = out[:, 0], out[:, 1]
mask_real = F.pad(mask_real, [0, 0, 1, 0])
mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
        if self.masking_mode == "E":
            # polar masking: tanh-bounded magnitude mask, additive phase mask
            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
            real_phase = mask_real / (mask_mag + 1e-8)
            imag_phase = mask_imag / (mask_mag + 1e-8)
            mask_phase = torch.atan2(imag_phase, real_phase)
            mask_mag = torch.tanh(mask_mag)
            est_mag = mask_mag * mag_spec
            est_phase = phase_spec + mask_phase
            # polar to rectangular: mag * (cos(theta) + i sin(theta))
            real = est_mag * torch.cos(est_phase)
            imag = est_mag * torch.sin(est_phase)
        elif self.masking_mode == "C":
            # complex ratio mask; tuple assignment so the updated real
            # does not leak into the imag computation
            real, imag = (
                real * mask_real - imag * mask_imag,
                real * mask_imag + imag * mask_real,
            )
        else:
            real = real * mask_real
            imag = imag * mask_imag
spec = torch.cat([real, imag], 1)
wav = self.istft(spec)
wav = wav.clamp_(-1, 1)
return wav
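A usage sketch for the new model, consistent with the DCCRN test below (weights are randomly initialised here, so the output is not meaningful audio):

import torch
from enhancer.models.dccrn import DCCRN

model = DCCRN()  # defaults: 16 kHz, masking_mode="E"
model.eval()
audio = torch.rand(1, 1, 16000)  # (batch, channels, samples), mono
with torch.no_grad():
    enhanced = model(audio)  # enhanced waveform with the input's layout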

View File: enhancer/models/demucs.py

@@ -204,9 +204,9 @@ class Demucs(Model):
         if waveform.dim() == 2:
             waveform = waveform.unsqueeze(1)

-        if waveform.size(1) != 1:
-            raise TypeError(
-                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
+        if waveform.size(1) != self.hparams.num_channels:
+            raise ValueError(
+                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
             )
         if self.normalize:
             waveform = waveform.mean(dim=1, keepdim=True)

View File: enhancer/models/model.py

@@ -2,7 +2,7 @@ import os
 from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import List, Optional, Text, Union
+from typing import Any, List, Optional, Text, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
 import torch
 from huggingface_hub import cached_download, hf_hub_url
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from torch import nn
 from torch.optim import Adam

 from enhancer.data.dataset import EnhancerDataset
@@ -36,7 +37,7 @@ class Model(pl.LightningModule):
         Enhancer dataset used for training/validation
     duration: float, optional
         duration used for training/inference
-    loss : string or List of strings, default to "mse"
+    loss : string or List of strings or custom loss (nn.Module), default to "mse"
         loss functions to be used. Available ("mse","mae","Si-SDR")
     """
@@ -49,7 +50,7 @@ class Model(pl.LightningModule):
         dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric: Union[str, List] = "mse",
+        metric: Union[str, List, Any] = "mse",
     ):
         super().__init__()
         assert (
@@ -86,10 +87,11 @@ class Model(pl.LightningModule):
     @metric.setter
     def metric(self, metric):
         self._metric = []
-        if isinstance(metric, str):
+        if isinstance(metric, (str, nn.Module)):
             metric = [metric]

         for func in metric:
-            if func in LOSS_MAP.keys():
-                if func in ("pesq", "stoi"):
-                    self._metric.append(
+            if isinstance(func, str):
+                if func in LOSS_MAP.keys():
+                    if func in ("pesq", "stoi"):
+                        self._metric.append(
@ -97,9 +99,13 @@ class Model(pl.LightningModule):
) )
else: else:
self._metric.append(LOSS_MAP[func]()) self._metric.append(LOSS_MAP[func]())
else: else:
raise ValueError(f"Invalid metrics {func}") ValueError(f"Invalid metrics {func}")
elif isinstance(func, nn.Module):
self._metric.append(func)
else:
raise ValueError("Invalid metrics")
@property @property
def dataset(self): def dataset(self):
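A sketch of what the widened metric setter enables: built-in metric names can now be mixed with any nn.Module that maps (prediction, target) to a score. The SNRMetric class here is hypothetical, not part of the repo.

import torch
from torch import nn

class SNRMetric(nn.Module):  # hypothetical custom metric
    def forward(self, prediction, target):
        noise = target - prediction
        return 10 * torch.log10(
            target.pow(2).mean() / (noise.pow(2).mean() + 1e-8)
        )

# e.g. DCCRN(..., metric=["mae", SNRMetric()])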

View File: enhancer/utils/transforms.py

@@ -0,0 +1,92 @@
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window
from torch import nn
class ConvFFT(nn.Module):
def __init__(
self,
window_len: int,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__()
self.window_len = window_len
        self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
self.window = torch.from_numpy(
get_window(window, window_len, fftbins=True).astype("float32")
)
def init_kernel(self, inverse=False):
fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
real, imag = np.real(fourier_basis), np.imag(fourier_basis)
kernel = np.concatenate([real, imag], 1).T
if inverse:
kernel = np.linalg.pinv(kernel).T
kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
kernel *= self.window
return kernel
class ConvSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel())
    def forward(self, input):
        if input.dim() < 2:
            raise ValueError(
                f"Expected a 2D or 3D input signal, got {input.dim()} dimensions"
            )
        elif input.dim() == 2:
            input = input.unsqueeze(1)
input = F.pad(
input,
(self.window_len - self.hop_size, self.window_len - self.hop_size),
)
output = F.conv1d(input, self.weight, stride=self.hop_size)
return output
class ConviSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel(True))
self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
def forward(self, input, phase=None):
if phase is not None:
real = input * torch.cos(phase)
imag = input * torch.sin(phase)
input = torch.cat([real, imag], 1)
out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
out = out / (coeff + 1e-8)
pad = self.window_len - self.hop_size
out = out[..., pad:-pad]
return out

View File

@@ -0,0 +1,50 @@
import torch
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.rnn import ComplexLSTM
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
def test_complexconv2d():
sample_input = torch.rand(1, 2, 256, 13)
conv = ComplexConv2d(
2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 32, 128, 13])
def test_complexconvtranspose2d():
sample_input = torch.rand(1, 512, 4, 13)
conv = ComplexConvTranspose2d(
256 * 2,
128 * 2,
kernel_size=(5, 2),
stride=(2, 1),
padding=(2, 0),
output_padding=(1, 0),
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 256, 8, 14])
def test_complexlstm():
sample_input = torch.rand(13, 2, 128)
lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
with torch.no_grad():
out = lstm(sample_input)
assert out[0].shape == torch.Size([13, 1, 512])
assert out[1].shape == torch.Size([13, 1, 512])
def test_complexbatchnorm2d():
sample_input = torch.rand(1, 64, 64, 14)
batchnorm = ComplexBatchNorm2D(num_features=64)
with torch.no_grad():
out = batchnorm(sample_input)
assert out.size() == sample_input.size()

View File

@@ -30,7 +30,7 @@ def test_forward(batch_size, samples):
     data = torch.rand(batch_size, 2, samples, requires_grad=False)

     with torch.no_grad():
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             _ = model(data)

View File

@@ -0,0 +1,43 @@
import pytest
import torch
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.dccrn import DCCRN
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = DCCRN()
model.eval()
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(ValueError):
_ = model(data)
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_dccrn_init(dataset, channels, loss):
with torch.no_grad():
_ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)

View File: tests/transforms_test.py (new file, 18 lines)

@@ -0,0 +1,18 @@
import torch
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
def test_stft_istft():
sample_input = torch.rand(1, 1, 16000)
stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
with torch.no_grad():
spectrogram = stft(sample_input)
waveform = istft(spectrogram)
assert sample_input.shape == waveform.shape
assert (
torch.isclose(waveform, sample_input).sum().item()
> sample_input.shape[-1] // 2
)