Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-11-07 11:54:25 +05:30
commit ce04720e59
14 changed files with 1011 additions and 25 deletions

View File

@ -0,0 +1,25 @@
_target_: enhancer.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : True
encoder_decoder:
initial_output_channels : 32
depth : 6
kernel_size : 5
growth_factor : 2
stride : 2
padding : 2
output_padding : 1
lstm:
num_layers : 2
hidden_size : 256
stft:
window_len : 400
hop_size : 100
nfft : 512

View File

@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule):
name: str,
root_dir: str,
files: Files,
valid_minutes: float = 0.20,
min_valid_minutes: float = 0.20,
duration: float = 1.0,
stride=None,
sampling_rate: int = 48000,
@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule):
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
if valid_minutes > 0.0:
self.valid_minutes = valid_minutes
if min_valid_minutes > 0.0:
self.min_valid_minutes = min_valid_minutes
else:
raise ValueError("valid_minutes must be greater than 0")
raise ValueError("min_valid_minutes must be greater than 0")
self.augmentations = augmentations
@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule):
)
train_data = fp.prepare_matching_dict()
train_data, self.val_data = self.train_valid_split(
train_data, valid_minutes=self.valid_minutes, random_state=42
train_data,
min_valid_minutes=self.min_valid_minutes,
random_state=42,
)
self.train_data = self.prepare_traindata(train_data)
@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule):
self._test = self.prepare_mapstype(test_data)
def train_valid_split(
self, data, valid_minutes: float = 20, random_state: int = 42
self, data, min_valid_minutes: float = 20, random_state: int = 42
):
valid_minutes *= 60
min_valid_minutes *= 60
valid_sec_now = 0.0
valid_indices = []
all_speakers = np.unique(
@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule):
possible_indices = list(range(0, len(all_speakers)))
rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= valid_minutes:
while valid_sec_now <= min_valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
@ -257,6 +259,9 @@ class EnhancerDataset(TaskDataset):
files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer enhancer.utils.Files dataclass)
min_valid_minutes: float
minimum validation split size time in minutes
algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.
duration : float
expected audio duration of single audio sample for training
sampling_rate : int
@ -271,6 +276,7 @@ class EnhancerDataset(TaskDataset):
use one_to_many mapping for multiple noisy files for each clean file
"""
def __init__(
@ -278,7 +284,7 @@ class EnhancerDataset(TaskDataset):
name: str,
root_dir: str,
files: Files,
valid_minutes=5.0,
min_valid_minutes=5.0,
duration=1.0,
stride=None,
sampling_rate=48000,
@ -292,7 +298,7 @@ class EnhancerDataset(TaskDataset):
name=name,
root_dir=root_dir,
files=files,
valid_minutes=valid_minutes,
min_valid_minutes=min_valid_minutes,
sampling_rate=sampling_rate,
duration=duration,
matching_function=matching_function,

View File

@ -0,0 +1,5 @@
from enhancer.models.complexnn.conv import ComplexConv2d # noqa
from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa
from enhancer.models.complexnn.rnn import ComplexLSTM # noqa
from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa
from enhancer.models.complexnn.utils import ComplexRelu # noqa

View File

@ -0,0 +1,136 @@
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn
def init_weights(nnet):
nn.init.xavier_normal_(nnet.weight.data)
nn.init.constant_(nnet.bias, 0.0)
return nnet
class ComplexConv2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
groups: int = 1,
dilation: int = 1,
):
"""
Complex Conv2d (non-causal)
"""
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.dilation = dilation
self.real_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = nn.Conv2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.padding[0], 0),
groups=self.groups,
dilation=self.dilation,
)
self.imag_conv = init_weights(self.imag_conv)
self.real_conv = init_weights(self.real_conv)
def forward(self, input):
"""
complex axis should be always 1 dim
"""
input = F.pad(input, [self.padding[1], 0, 0, 0])
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real)
real_imag = self.imag_conv(real)
imag_imag = self.imag_conv(imag)
imag_real = self.real_conv(imag)
real = real_real - imag_imag
imag = real_imag - imag_real
out = torch.cat([real, imag], 1)
return out
class ComplexConvTranspose2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int] = (1, 1),
stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0),
output_padding: Tuple[int, int] = (0, 0),
groups: int = 1,
):
super().__init__()
self.in_channels = in_channels // 2
self.out_channels = out_channels // 2
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.groups = groups
self.output_padding = output_padding
self.real_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.imag_conv = nn.ConvTranspose2d(
self.in_channels,
self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
output_padding=self.output_padding,
groups=self.groups,
)
self.real_conv = init_weights(self.real_conv)
self.imag_conv = init_weights(self.imag_conv)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_conv(real)
real_imag = self.imag_conv(real)
imag_imag = self.imag_conv(imag)
imag_real = self.real_conv(imag)
real = real_real - imag_imag
imag = real_imag - imag_real
out = torch.cat([real, imag], 1)
return out

View File

@ -0,0 +1,68 @@
from typing import List, Optional
import torch
from torch import nn
class ComplexLSTM(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int = 1,
projection_size: Optional[int] = None,
bidirectional: bool = False,
):
super().__init__()
self.input_size = input_size // 2
self.hidden_size = hidden_size // 2
self.num_layers = num_layers
self.real_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
self.imag_lstm = nn.LSTM(
self.input_size,
self.hidden_size,
self.num_layers,
bidirectional=bidirectional,
batch_first=False,
)
bidirectional = 2 if bidirectional else 1
if projection_size is not None:
self.projection_size = projection_size // 2
self.real_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
self.imag_linear = nn.Linear(
self.hidden_size * bidirectional, self.projection_size
)
else:
self.projection_size = None
def forward(self, input):
if isinstance(input, List):
real, imag = input
else:
real, imag = torch.chunk(input, 2, 1)
real_real = self.real_lstm(real)[0]
real_imag = self.imag_lstm(real)[0]
imag_imag = self.imag_lstm(imag)[0]
imag_real = self.real_lstm(imag)[0]
real = real_real - imag_imag
imag = imag_real + real_imag
if self.projection_size is not None:
real = self.real_linear(real)
imag = self.imag_linear(imag)
return [real, imag]

View File

@ -0,0 +1,199 @@
import torch
from torch import nn
class ComplexBatchNorm2D(nn.Module):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
):
"""
Complex batch normalization 2D
https://arxiv.org/abs/1705.09792
"""
super().__init__()
self.num_features = num_features // 2
self.affine = affine
self.momentum = momentum
self.track_running_stats = track_running_stats
self.eps = eps
if self.affine:
self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
else:
self.register_parameter("Wrr", None)
self.register_parameter("Wri", None)
self.register_parameter("Wii", None)
self.register_parameter("Br", None)
self.register_parameter("Bi", None)
if self.track_running_stats:
values = torch.zeros(self.num_features)
self.register_buffer("Mean_real", values)
self.register_buffer("Mean_imag", values)
self.register_buffer("Var_rr", values)
self.register_buffer("Var_ri", values)
self.register_buffer("Var_ii", values)
self.register_buffer(
"num_batches_tracked", torch.tensor(0, dtype=torch.long)
)
else:
self.register_parameter("Mean_real", None)
self.register_parameter("Mean_imag", None)
self.register_parameter("Var_rr", None)
self.register_parameter("Var_ri", None)
self.register_parameter("Var_ii", None)
self.register_parameter("num_batches_tracked", None)
self.reset_parameters()
def reset_parameters(self):
if self.affine:
self.Wrr.data.fill_(1)
self.Wii.data.fill_(1)
self.Wri.data.uniform_(-0.9, 0.9)
self.Br.data.fill_(0)
self.Bi.data.fill_(0)
self.reset_running_stats()
def reset_running_stats(self):
if self.track_running_stats:
self.Mean_real.zero_()
self.Mean_imag.zero_()
self.Var_rr.fill_(1)
self.Var_ri.zero_()
self.Var_ii.fill_(1)
self.num_batches_tracked.zero_()
def extra_repr(self):
return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format(
**self.__dict__
)
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
exp_avg_factor = 0.0
training = self.training and self.track_running_stats
if training:
self.num_batches_tracked += 1
if self.momentum is None:
exp_avg_factor = 1 / self.num_batches_tracked
else:
exp_avg_factor = self.momentum
redux = [i for i in reversed(range(real.dim())) if i != 1]
vdim = [1] * real.dim()
vdim[1] = real.size(1)
if training:
batch_mean_real, batch_mean_imag = real, imag
for dim in redux:
batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
if self.track_running_stats:
self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
else:
batch_mean_real = self.Mean_real.view(vdim)
batch_mean_imag = self.Mean_imag.view(vdim)
real = real - batch_mean_real
imag = imag - batch_mean_imag
if training:
batch_var_rr = real * real
batch_var_ri = real * imag
batch_var_ii = imag * imag
for dim in redux:
batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
if self.track_running_stats:
self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
else:
batch_var_rr = self.Var_rr.view(vdim)
batch_var_ii = self.Var_ii.view(vdim)
batch_var_ri = self.Var_ri.view(vdim)
batch_var_rr += self.eps
batch_var_ii += self.eps
# Covariance matrics
# | batch_var_rr batch_var_ri |
# | batch_var_ir batch_var_ii | here batch_var_ir == batch_var_ri
# Inverse square root of cov matrix by combining below two formulas
# https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
# https://mathworld.wolfram.com/MatrixInverse.html
tau = batch_var_rr + batch_var_ii
s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
t = (tau + 2 * s).sqrt()
rst = (s * t).reciprocal()
Urr = (batch_var_ii + s) * rst
Uri = -batch_var_ri * rst
Uii = (batch_var_rr + s) * rst
if self.affine:
Wrr, Wri, Wii = (
self.Wrr.view(vdim),
self.Wri.view(vdim),
self.Wii.view(vdim),
)
Zrr = (Wrr * Urr) + (Wri * Uri)
Zri = (Wrr * Uri) + (Wri * Uii)
Zir = (Wii * Uri) + (Wri * Urr)
Zii = (Wri * Uri) + (Wii * Uii)
else:
Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
yr = (Zrr * real) + (Zri * imag)
yi = (Zir * real) + (Zii * imag)
if self.affine:
yr = yr + self.Br.view(vdim)
yi = yi + self.Bi.view(vdim)
outputs = torch.cat([yr, yi], 1)
return outputs
class ComplexRelu(nn.Module):
def __init__(self):
super().__init__()
self.real_relu = nn.PReLU()
self.imag_relu = nn.PReLU()
def forward(self, input):
real, imag = torch.chunk(input, 2, 1)
real = self.real_relu(real)
imag = self.imag_relu(imag)
return torch.cat([real, imag], dim=1)
def complex_cat(inputs, axis=1):
real, imag = [], []
for data in inputs:
real_data, imag_data = torch.chunk(data, 2, axis)
real.append(real_data)
imag.append(imag_data)
real = torch.cat(real, axis)
imag = torch.cat(imag, axis)
return torch.cat([real, imag], axis)

338
enhancer/models/dccrn.py Normal file
View File

@ -0,0 +1,338 @@
import logging
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import (
ComplexBatchNorm2D,
ComplexConv2d,
ComplexConvTranspose2d,
ComplexLSTM,
ComplexRelu,
)
from enhancer.models.complexnn.utils import complex_cat
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict
class DCCRN_ENCODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channel: int,
kernel_size: Tuple[int, int],
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 1),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
self.encoder = nn.Sequential(
ComplexConv2d(
in_channels,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
batchnorm(out_channel),
activation,
)
def forward(self, waveform):
return self.encoder(waveform)
class DCCRN_DECODER(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple[int, int],
layer: int = 0,
complex_norm: bool = True,
complex_relu: bool = True,
stride: Tuple[int, int] = (2, 1),
padding: Tuple[int, int] = (2, 0),
output_padding: Tuple[int, int] = (1, 0),
):
super().__init__()
batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
activation = ComplexRelu() if complex_relu else nn.PReLU()
if layer != 0:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
),
batchnorm(out_channels),
activation,
)
else:
self.decoder = nn.Sequential(
ComplexConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
)
)
def forward(self, waveform):
return self.decoder(waveform)
class DCCRN(Model):
STFT_DEFAULTS = {
"window_len": 400,
"hop_size": 100,
"nfft": 512,
"window": "hamming",
}
ED_DEFAULTS = {
"initial_output_channels": 32,
"depth": 6,
"kernel_size": 5,
"growth_factor": 2,
"stride": 2,
"padding": 2,
"output_padding": 1,
}
LSTM_DEFAULTS = {
"num_layers": 2,
"hidden_size": 256,
}
def __init__(
self,
stft: Optional[dict] = None,
encoder_decoder: Optional[dict] = None,
lstm: Optional[dict] = None,
complex_lstm: bool = True,
complex_norm: bool = True,
complex_relu: bool = True,
masking_mode: str = "E",
num_channels: int = 1,
sampling_rate=16000,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List, Any] = "mse",
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
super().__init__(
num_channels=num_channels,
sampling_rate=sampling_rate,
lr=lr,
dataset=dataset,
duration=duration,
loss=loss,
metric=metric,
)
encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
stft = merge_dict(self.STFT_DEFAULTS, stft)
self.save_hyperparameters(
"encoder_decoder",
"lstm",
"stft",
"complex_lstm",
"complex_norm",
"masking_mode",
)
self.complex_lstm = complex_lstm
self.complex_norm = complex_norm
self.masking_mode = masking_mode
self.stft = ConvSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.istft = ConviSTFT(
stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
)
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
num_channels *= 2
hidden_size = encoder_decoder["initial_output_channels"]
growth_factor = 2
for layer in range(encoder_decoder["depth"]):
encoder_ = DCCRN_ENCODER(
num_channels,
hidden_size,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 1),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.encoder.append(encoder_)
decoder_ = DCCRN_DECODER(
hidden_size + hidden_size,
num_channels,
layer=layer,
kernel_size=(encoder_decoder["kernel_size"], 2),
stride=(encoder_decoder["stride"], 1),
padding=(encoder_decoder["padding"], 0),
output_padding=(encoder_decoder["output_padding"], 0),
complex_norm=complex_norm,
complex_relu=complex_relu,
)
self.decoder.insert(0, decoder_)
if layer < encoder_decoder["depth"] - 3:
num_channels = hidden_size
hidden_size *= growth_factor
else:
num_channels = hidden_size
kernel_size = hidden_size / 2
hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])
if self.complex_lstm:
lstms = []
for layer in range(lstm["num_layers"]):
if layer == 0:
input_size = int(hidden_size * kernel_size)
else:
input_size = lstm["hidden_size"]
if layer == lstm["num_layers"] - 1:
projection_size = int(hidden_size * kernel_size)
else:
projection_size = None
kwargs = {
"input_size": input_size,
"hidden_size": lstm["hidden_size"],
"num_layers": 1,
}
lstms.append(
ComplexLSTM(projection_size=projection_size, **kwargs)
)
self.lstm = nn.Sequential(*lstms)
else:
self.lstm = nn.Sequential(
nn.LSTM(
input_size=hidden_size * kernel_size,
hidden_sizs=lstm["hidden_size"],
num_layers=lstm["num_layers"],
dropout=0.0,
batch_first=False,
)[0],
nn.Linear(lstm["hidden"], hidden_size * kernel_size),
)
def forward(self, waveform):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != self.hparams.num_channels:
raise ValueError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
)
waveform_stft = self.stft(waveform)
real = waveform_stft[:, : self.stft.nfft // 2 + 1]
imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
phase_spec = torch.atan2(imag, real)
complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
encoder_outputs = []
out = complex_spec
for _, encoder in enumerate(self.encoder):
out = encoder(out)
encoder_outputs.append(out)
B, C, D, T = out.size()
out = out.permute(3, 0, 1, 2)
if self.complex_lstm:
lstm_real = out[:, :, : C // 2]
lstm_imag = out[:, :, C // 2 :]
lstm_real = lstm_real.reshape(T, B, C // 2 * D)
lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
lstm_real = lstm_real.reshape(T, B, C // 2, D)
lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
out = torch.cat([lstm_real, lstm_imag], 2)
else:
out = out.reshape(T, B, C * D)
out = self.lstm(out)
out = out.reshape(T, B, D, C)
out = out.permute(1, 2, 3, 0)
for layer, decoder in enumerate(self.decoder):
skip_connection = encoder_outputs.pop(-1)
out = complex_cat([skip_connection, out])
out = decoder(out)
out = out[..., 1:]
mask_real, mask_imag = out[:, 0], out[:, 1]
mask_real = F.pad(mask_real, [0, 0, 1, 0])
mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
if self.masking_mode == "E":
mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
real_phase = mask_real / (mask_mag + 1e-8)
imag_phase = mask_imag / (mask_mag + 1e-8)
mask_phase = torch.atan2(imag_phase, real_phase)
mask_mag = torch.tanh(mask_mag)
est_mag = mask_mag * mag_spec
est_phase = mask_phase * phase_spec
# cos(theta) + isin(theta)
real = est_mag + torch.cos(est_phase)
imag = est_mag + torch.sin(est_phase)
if self.masking_mode == "C":
real = real * mask_real - imag * mask_imag
imag = real * mask_imag + imag * mask_real
else:
real = real * mask_real
imag = imag * mask_imag
spec = torch.cat([real, imag], 1)
wav = self.istft(spec)
wav = wav.clamp_(-1, 1)
return wav

View File

@ -204,9 +204,9 @@ class Demucs(Model):
if waveform.dim() == 2:
waveform = waveform.unsqueeze(1)
if waveform.size(1) != 1:
raise TypeError(
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
if waveform.size(1) != self.hparams.num_channels:
raise ValueError(
f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
)
if self.normalize:
waveform = waveform.mean(dim=1, keepdim=True)

View File

@ -2,7 +2,7 @@ import os
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import List, Optional, Text, Union
from typing import Any, List, Optional, Text, Union
from urllib.parse import urlparse
import numpy as np
@ -10,6 +10,7 @@ import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch import nn
from torch.optim import Adam
from enhancer.data.dataset import EnhancerDataset
@ -36,7 +37,7 @@ class Model(pl.LightningModule):
Enhancer dataset used for training/validation
duration: float, optional
duration used for training/inference
loss : string or List of strings, default to "mse"
loss : string or List of strings or custom loss (nn.Module), default to "mse"
loss functions to be used. Available ("mse","mae","Si-SDR")
"""
@ -49,7 +50,7 @@ class Model(pl.LightningModule):
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
metric: Union[str, List, Any] = "mse",
):
super().__init__()
assert (
@ -86,10 +87,11 @@ class Model(pl.LightningModule):
@metric.setter
def metric(self, metric):
self._metric = []
if isinstance(metric, str):
if isinstance(metric, (str, nn.Module)):
metric = [metric]
for func in metric:
if isinstance(func, str):
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
@ -97,9 +99,13 @@ class Model(pl.LightningModule):
)
else:
self._metric.append(LOSS_MAP[func]())
else:
raise ValueError(f"Invalid metrics {func}")
ValueError(f"Invalid metrics {func}")
elif isinstance(func, nn.Module):
self._metric.append(func)
else:
raise ValueError("Invalid metrics")
@property
def dataset(self):

View File

@ -0,0 +1,92 @@
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window
from torch import nn
class ConvFFT(nn.Module):
def __init__(
self,
window_len: int,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__()
self.window_len = window_len
self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len)))
self.window = torch.from_numpy(
get_window(window, window_len, fftbins=True).astype("float32")
)
def init_kernel(self, inverse=False):
fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
real, imag = np.real(fourier_basis), np.imag(fourier_basis)
kernel = np.concatenate([real, imag], 1).T
if inverse:
kernel = np.linalg.pinv(kernel).T
kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
kernel *= self.window
return kernel
class ConvSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel())
def forward(self, input):
if input.dim() < 2:
raise ValueError(
f"Expected signal with shape 2 or 3 got {input.dim()}"
)
elif input.dim() == 2:
input = input.unsqueeze(1)
else:
pass
input = F.pad(
input,
(self.window_len - self.hop_size, self.window_len - self.hop_size),
)
output = F.conv1d(input, self.weight, stride=self.hop_size)
return output
class ConviSTFT(ConvFFT):
def __init__(
self,
window_len: int,
hop_size: Optional[int] = None,
nfft: Optional[int] = None,
window: str = "hamming",
):
super().__init__(window_len=window_len, nfft=nfft, window=window)
self.hop_size = hop_size if hop_size else window_len // 2
self.register_buffer("weight", self.init_kernel(True))
self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
def forward(self, input, phase=None):
if phase is not None:
real = input * torch.cos(phase)
imag = input * torch.sin(phase)
input = torch.cat([real, imag], 1)
out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
out = out / (coeff + 1e-8)
pad = self.window_len - self.hop_size
out = out[..., pad:-pad]
return out

View File

@ -0,0 +1,50 @@
import torch
from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
from enhancer.models.complexnn.rnn import ComplexLSTM
from enhancer.models.complexnn.utils import ComplexBatchNorm2D
def test_complexconv2d():
sample_input = torch.rand(1, 2, 256, 13)
conv = ComplexConv2d(
2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 32, 128, 13])
def test_complexconvtranspose2d():
sample_input = torch.rand(1, 512, 4, 13)
conv = ComplexConvTranspose2d(
256 * 2,
128 * 2,
kernel_size=(5, 2),
stride=(2, 1),
padding=(2, 0),
output_padding=(1, 0),
)
with torch.no_grad():
out = conv(sample_input)
assert out.shape == torch.Size([1, 256, 8, 14])
def test_complexlstm():
sample_input = torch.rand(13, 2, 128)
lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
with torch.no_grad():
out = lstm(sample_input)
assert out[0].shape == torch.Size([13, 1, 512])
assert out[1].shape == torch.Size([13, 1, 512])
def test_complexbatchnorm2d():
sample_input = torch.rand(1, 64, 64, 14)
batchnorm = ComplexBatchNorm2D(num_features=64)
with torch.no_grad():
out = batchnorm(sample_input)
assert out.size() == sample_input.size()

View File

@ -30,7 +30,7 @@ def test_forward(batch_size, samples):
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(TypeError):
with pytest.raises(ValueError):
_ = model(data)

View File

@ -0,0 +1,43 @@
import pytest
import torch
from enhancer.data.dataset import EnhancerDataset
from enhancer.models.dccrn import DCCRN
from enhancer.utils.config import Files
@pytest.fixture
def vctk_dataset():
root_dir = "tests/data/vctk"
files = Files(
train_clean="clean_testset_wav",
train_noisy="noisy_testset_wav",
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
return dataset
@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
model = DCCRN()
model.eval()
data = torch.rand(batch_size, 1, samples, requires_grad=False)
with torch.no_grad():
_ = model(data)
data = torch.rand(batch_size, 2, samples, requires_grad=False)
with torch.no_grad():
with pytest.raises(ValueError):
_ = model(data)
@pytest.mark.parametrize(
"dataset,channels,loss",
[(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
with torch.no_grad():
_ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)

18
tests/transforms_test.py Normal file
View File

@ -0,0 +1,18 @@
import torch
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
def test_stft_istft():
sample_input = torch.rand(1, 1, 16000)
stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
with torch.no_grad():
spectrogram = stft(sample_input)
waveform = istft(spectrogram)
assert sample_input.shape == waveform.shape
assert (
torch.isclose(waveform, sample_input).sum().item()
> sample_input.shape[-1] // 2
)