Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
Commit: ce04720e59
@@ -0,0 +1,25 @@
+_target_: enhancer.models.dccrn.DCCRN
+num_channels: 1
+sampling_rate: 16000
+complex_lstm: True
+complex_norm: True
+complex_relu: True
+masking_mode: E  # DCCRN expects a masking-mode string ("E", "C", or "R"); a boolean here would silently fall through to plain real/imag masking
+
+encoder_decoder:
+  initial_output_channels: 32
+  depth: 6
+  kernel_size: 5
+  growth_factor: 2
+  stride: 2
+  padding: 2
+  output_padding: 1
+
+lstm:
+  num_layers: 2
+  hidden_size: 256
+
+stft:
+  window_len: 400
+  hop_size: 100
+  nfft: 512
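A minimal sketch of how such a config is consumed, assuming hydra-core and omegaconf are installed; the file path below is hypothetical, since the diff view does not show the file name. Hydra's instantiate resolves _target_ and passes the remaining keys as keyword arguments:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# hypothetical path; the actual location of this YAML is not shown in the diff
cfg = OmegaConf.load("models/DCCRN.yaml")
model = instantiate(cfg)  # resolves _target_ -> enhancer.models.dccrn.DCCRN(num_channels=1, ...)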
@@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule):
         name: str,
         root_dir: str,
         files: Files,
-        valid_minutes: float = 0.20,
+        min_valid_minutes: float = 0.20,
         duration: float = 1.0,
         stride=None,
         sampling_rate: int = 48000,
@@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
-        if valid_minutes > 0.0:
-            self.valid_minutes = valid_minutes
+        if min_valid_minutes > 0.0:
+            self.min_valid_minutes = min_valid_minutes
         else:
-            raise ValueError("valid_minutes must be greater than 0")
+            raise ValueError("min_valid_minutes must be greater than 0")

         self.augmentations = augmentations
@@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule):
         )
         train_data = fp.prepare_matching_dict()
         train_data, self.val_data = self.train_valid_split(
-            train_data, valid_minutes=self.valid_minutes, random_state=42
+            train_data,
+            min_valid_minutes=self.min_valid_minutes,
+            random_state=42,
         )

         self.train_data = self.prepare_traindata(train_data)
@@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule):
         self._test = self.prepare_mapstype(test_data)

     def train_valid_split(
-        self, data, valid_minutes: float = 20, random_state: int = 42
+        self, data, min_valid_minutes: float = 20, random_state: int = 42
     ):

-        valid_minutes *= 60
+        min_valid_minutes *= 60
         valid_sec_now = 0.0
         valid_indices = []
         all_speakers = np.unique(
@@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule):
         possible_indices = list(range(0, len(all_speakers)))
         rng = create_unique_rng(len(all_speakers))

-        while valid_sec_now <= valid_minutes:
+        while valid_sec_now <= min_valid_minutes:
             speaker_index = rng.choice(possible_indices)
             possible_indices.remove(speaker_index)
             speaker_name = all_speakers[speaker_index]
@@ -257,6 +259,9 @@ class EnhancerDataset(TaskDataset):
     files : Files
         dataclass containing train_clean, train_noisy, test_clean, test_noisy
         folder names (refer enhancer.utils.Files dataclass)
+    min_valid_minutes: float
+        minimum validation split size, in minutes
+        the algorithm randomly selects speakers (totalling >= min_valid_minutes of audio) from the train data to form the validation data
     duration : float
         expected audio duration of single audio sample for training
     sampling_rate : int
@@ -271,6 +276,7 @@ class EnhancerDataset(TaskDataset):
         use one_to_many mapping for multiple noisy files for each clean file

+
     """

     def __init__(
@@ -278,7 +284,7 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
-        valid_minutes=5.0,
+        min_valid_minutes=5.0,
         duration=1.0,
         stride=None,
         sampling_rate=48000,
@@ -292,7 +298,7 @@ class EnhancerDataset(TaskDataset):
             name=name,
             root_dir=root_dir,
             files=files,
-            valid_minutes=valid_minutes,
+            min_valid_minutes=min_valid_minutes,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,
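The renamed min_valid_minutes argument reserves at least that many minutes of audio, drawn speaker by speaker, for validation. A usage sketch mirroring the vctk fixture from the tests added later in this commit:

from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.config import Files

files = Files(
    train_clean="clean_testset_wav",
    train_noisy="noisy_testset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
# speakers are drawn at random until at least 5 minutes of audio is reserved
dataset = EnhancerDataset(
    name="vctk", root_dir="tests/data/vctk", files=files, min_valid_minutes=5.0
)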
@@ -0,0 +1,5 @@
+from enhancer.models.complexnn.conv import ComplexConv2d  # noqa
+from enhancer.models.complexnn.conv import ComplexConvTranspose2d  # noqa
+from enhancer.models.complexnn.rnn import ComplexLSTM  # noqa
+from enhancer.models.complexnn.utils import ComplexBatchNorm2D  # noqa
+from enhancer.models.complexnn.utils import ComplexRelu  # noqa
@@ -0,0 +1,136 @@
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def init_weights(nnet):
+    nn.init.xavier_normal_(nnet.weight.data)
+    nn.init.constant_(nnet.bias, 0.0)
+    return nnet
+
+
+class ComplexConv2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int] = (1, 1),
+        stride: Tuple[int, int] = (1, 1),
+        padding: Tuple[int, int] = (0, 0),
+        groups: int = 1,
+        dilation: int = 1,
+    ):
+        """
+        Complex Conv2d (non-causal). in_channels/out_channels count real and
+        imaginary channels together, so both must be even.
+        """
+        super().__init__()
+        self.in_channels = in_channels // 2
+        self.out_channels = out_channels // 2
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.groups = groups
+        self.dilation = dilation
+
+        self.real_conv = nn.Conv2d(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=(self.padding[0], 0),
+            groups=self.groups,
+            dilation=self.dilation,
+        )
+        self.imag_conv = nn.Conv2d(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=(self.padding[0], 0),
+            groups=self.groups,
+            dilation=self.dilation,
+        )
+        self.imag_conv = init_weights(self.imag_conv)
+        self.real_conv = init_weights(self.real_conv)
+
+    def forward(self, input):
+        """
+        The complex axis is always dim 1: the first half of the channels
+        holds the real part, the second half the imaginary part.
+        """
+        # pad the time axis on the left only (causal padding in time)
+        input = F.pad(input, [self.padding[1], 0, 0, 0])
+
+        real, imag = torch.chunk(input, 2, 1)
+
+        real_real = self.real_conv(real)
+        real_imag = self.imag_conv(real)
+
+        imag_imag = self.imag_conv(imag)
+        imag_real = self.real_conv(imag)
+
+        # complex multiplication: (a + ib)(c + id) = (ac - bd) + i(ad + bc)
+        real = real_real - imag_imag
+        imag = real_imag + imag_real
+
+        out = torch.cat([real, imag], 1)
+        return out
+
+
+class ComplexConvTranspose2d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int] = (1, 1),
+        stride: Tuple[int, int] = (1, 1),
+        padding: Tuple[int, int] = (0, 0),
+        output_padding: Tuple[int, int] = (0, 0),
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels // 2
+        self.out_channels = out_channels // 2
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.groups = groups
+        self.output_padding = output_padding
+
+        self.real_conv = nn.ConvTranspose2d(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            output_padding=self.output_padding,
+            groups=self.groups,
+        )
+
+        self.imag_conv = nn.ConvTranspose2d(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            output_padding=self.output_padding,
+            groups=self.groups,
+        )
+
+        self.real_conv = init_weights(self.real_conv)
+        self.imag_conv = init_weights(self.imag_conv)
+
+    def forward(self, input):
+        real, imag = torch.chunk(input, 2, 1)
+
+        real_real = self.real_conv(real)
+        real_imag = self.imag_conv(real)
+
+        imag_imag = self.imag_conv(imag)
+        imag_real = self.real_conv(imag)
+
+        # complex multiplication: (a + ib)(c + id) = (ac - bd) + i(ad + bc)
+        real = real_real - imag_imag
+        imag = real_imag + imag_real
+
+        out = torch.cat([real, imag], 1)
+
+        return out
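A shape sketch for ComplexConv2d, matching the test added later in this commit; channel counts include both halves, and dim 1 is split into real and imaginary parts:

import torch

from enhancer.models.complexnn.conv import ComplexConv2d

x = torch.rand(1, 2, 256, 13)  # (batch, 2 * C, freq, time) with C = 1
conv = ComplexConv2d(2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1))
with torch.no_grad():
    y = conv(x)  # (1, 32, 128, 13): 16 real + 16 imaginary channels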
@@ -0,0 +1,68 @@
+from typing import Optional
+
+import torch
+from torch import nn
+
+
+class ComplexLSTM(nn.Module):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int = 1,
+        projection_size: Optional[int] = None,
+        bidirectional: bool = False,
+    ):
+        super().__init__()
+        self.input_size = input_size // 2
+        self.hidden_size = hidden_size // 2
+        self.num_layers = num_layers
+
+        self.real_lstm = nn.LSTM(
+            self.input_size,
+            self.hidden_size,
+            self.num_layers,
+            bidirectional=bidirectional,
+            batch_first=False,
+        )
+        self.imag_lstm = nn.LSTM(
+            self.input_size,
+            self.hidden_size,
+            self.num_layers,
+            bidirectional=bidirectional,
+            batch_first=False,
+        )
+
+        bidirectional = 2 if bidirectional else 1
+        if projection_size is not None:
+            self.projection_size = projection_size // 2
+            self.real_linear = nn.Linear(
+                self.hidden_size * bidirectional, self.projection_size
+            )
+            self.imag_linear = nn.Linear(
+                self.hidden_size * bidirectional, self.projection_size
+            )
+        else:
+            self.projection_size = None
+
+    def forward(self, input):
+        if isinstance(input, list):
+            real, imag = input
+        else:
+            real, imag = torch.chunk(input, 2, 1)
+
+        real_real = self.real_lstm(real)[0]
+        real_imag = self.imag_lstm(real)[0]
+
+        imag_imag = self.imag_lstm(imag)[0]
+        imag_real = self.real_lstm(imag)[0]
+
+        # complex multiplication: real parts subtract, imaginary parts add
+        real = real_real - imag_imag
+        imag = imag_real + real_imag
+
+        if self.projection_size is not None:
+            real = self.real_linear(real)
+            imag = self.imag_linear(imag)
+
+        return [real, imag]
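A shape sketch for ComplexLSTM, matching the test added later in this commit; sizes are halved internally, so input, hidden, and projection sizes must be even:

import torch

from enhancer.models.complexnn.rnn import ComplexLSTM

x = torch.rand(13, 2, 128)  # (time, batch, features); dim 1 splits into real/imag
lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
with torch.no_grad():
    real, imag = lstm(x)  # each of shape (13, 1, 512)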
@@ -0,0 +1,199 @@
+import torch
+from torch import nn
+
+
+class ComplexBatchNorm2D(nn.Module):
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+        track_running_stats: bool = True,
+    ):
+        """
+        Complex batch normalization 2D
+        https://arxiv.org/abs/1705.09792
+        """
+        super().__init__()
+        self.num_features = num_features // 2
+        self.affine = affine
+        self.momentum = momentum
+        self.track_running_stats = track_running_stats
+        self.eps = eps
+
+        if self.affine:
+            self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features))
+            self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features))
+            self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features))
+            self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features))
+            self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features))
+        else:
+            self.register_parameter("Wrr", None)
+            self.register_parameter("Wri", None)
+            self.register_parameter("Wii", None)
+            self.register_parameter("Br", None)
+            self.register_parameter("Bi", None)
+
+        if self.track_running_stats:
+            # each buffer needs its own tensor; registering one shared tensor
+            # would make all running statistics alias the same storage
+            self.register_buffer("Mean_real", torch.zeros(self.num_features))
+            self.register_buffer("Mean_imag", torch.zeros(self.num_features))
+            self.register_buffer("Var_rr", torch.zeros(self.num_features))
+            self.register_buffer("Var_ri", torch.zeros(self.num_features))
+            self.register_buffer("Var_ii", torch.zeros(self.num_features))
+            self.register_buffer(
+                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
+            )
+        else:
+            self.register_parameter("Mean_real", None)
+            self.register_parameter("Mean_imag", None)
+            self.register_parameter("Var_rr", None)
+            self.register_parameter("Var_ri", None)
+            self.register_parameter("Var_ii", None)
+            self.register_parameter("num_batches_tracked", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.affine:
+            self.Wrr.data.fill_(1)
+            self.Wii.data.fill_(1)
+            self.Wri.data.uniform_(-0.9, 0.9)
+            self.Br.data.fill_(0)
+            self.Bi.data.fill_(0)
+        self.reset_running_stats()
+
+    def reset_running_stats(self):
+        if self.track_running_stats:
+            self.Mean_real.zero_()
+            self.Mean_imag.zero_()
+            self.Var_rr.fill_(1)
+            self.Var_ri.zero_()
+            self.Var_ii.fill_(1)
+            self.num_batches_tracked.zero_()
+
+    def extra_repr(self):
+        return (
+            "{num_features}, eps={eps}, momentum={momentum}, "
+            "affine={affine}, track_running_stats={track_running_stats}"
+        ).format(**self.__dict__)
+
+    def forward(self, input):
+        real, imag = torch.chunk(input, 2, 1)
+        exp_avg_factor = 0.0
+
+        training = self.training and self.track_running_stats
+        if training:
+            self.num_batches_tracked += 1
+            if self.momentum is None:
+                exp_avg_factor = 1 / self.num_batches_tracked
+            else:
+                exp_avg_factor = self.momentum
+
+        # reduce over every axis except the channel axis (dim 1)
+        redux = [i for i in reversed(range(real.dim())) if i != 1]
+        vdim = [1] * real.dim()
+        vdim[1] = real.size(1)
+
+        if training:
+            batch_mean_real, batch_mean_imag = real, imag
+            for dim in redux:
+                batch_mean_real = batch_mean_real.mean(dim, keepdim=True)
+                batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True)
+            if self.track_running_stats:
+                self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor)
+                self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor)
+        else:
+            batch_mean_real = self.Mean_real.view(vdim)
+            batch_mean_imag = self.Mean_imag.view(vdim)
+
+        real = real - batch_mean_real
+        imag = imag - batch_mean_imag
+
+        if training:
+            batch_var_rr = real * real
+            batch_var_ri = real * imag
+            batch_var_ii = imag * imag
+            for dim in redux:
+                batch_var_rr = batch_var_rr.mean(dim, keepdim=True)
+                batch_var_ri = batch_var_ri.mean(dim, keepdim=True)
+                batch_var_ii = batch_var_ii.mean(dim, keepdim=True)
+            if self.track_running_stats:
+                self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
+                self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
+                self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
+        else:
+            batch_var_rr = self.Var_rr.view(vdim)
+            batch_var_ii = self.Var_ii.view(vdim)
+            batch_var_ri = self.Var_ri.view(vdim)
+
+        # out-of-place add: the eval path holds views of the running-stat
+        # buffers, which an in-place += would silently mutate
+        batch_var_rr = batch_var_rr + self.eps
+        batch_var_ii = batch_var_ii + self.eps
+
+        # Covariance matrix
+        # | batch_var_rr batch_var_ri |
+        # | batch_var_ir batch_var_ii |   (batch_var_ir == batch_var_ri)
+        # Inverse square root of the covariance matrix, combining
+        # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
+        # https://mathworld.wolfram.com/MatrixInverse.html
+
+        tau = batch_var_rr + batch_var_ii  # trace
+        s = (
+            batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri
+        ).sqrt()  # square root of the determinant
+        t = (tau + 2 * s).sqrt()
+
+        rst = (s * t).reciprocal()
+        Urr = (batch_var_ii + s) * rst
+        Uri = -batch_var_ri * rst
+        Uii = (batch_var_rr + s) * rst
+
+        if self.affine:
+            Wrr, Wri, Wii = (
+                self.Wrr.view(vdim),
+                self.Wri.view(vdim),
+                self.Wii.view(vdim),
+            )
+            Zrr = (Wrr * Urr) + (Wri * Uri)
+            Zri = (Wrr * Uri) + (Wri * Uii)
+            Zir = (Wii * Uri) + (Wri * Urr)
+            Zii = (Wri * Uri) + (Wii * Uii)
+        else:
+            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
+
+        yr = (Zrr * real) + (Zri * imag)
+        yi = (Zir * real) + (Zii * imag)
+
+        if self.affine:
+            yr = yr + self.Br.view(vdim)
+            yi = yi + self.Bi.view(vdim)
+
+        outputs = torch.cat([yr, yi], 1)
+        return outputs
+
+
+class ComplexRelu(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.real_relu = nn.PReLU()
+        self.imag_relu = nn.PReLU()
+
+    def forward(self, input):
+        real, imag = torch.chunk(input, 2, 1)
+        real = self.real_relu(real)
+        imag = self.imag_relu(imag)
+        return torch.cat([real, imag], dim=1)
+
+
+def complex_cat(inputs, axis=1):
+    # concatenate the real halves and the imaginary halves separately,
+    # then stack them back into the [real, imag] channel layout
+    real, imag = [], []
+    for data in inputs:
+        real_data, imag_data = torch.chunk(data, 2, axis)
+        real.append(real_data)
+        imag.append(imag_data)
+    real = torch.cat(real, axis)
+    imag = torch.cat(imag, axis)
+    return torch.cat([real, imag], axis)
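The whitening step in ComplexBatchNorm2D.forward is the inverse square root of the 2x2 covariance matrix, restated here from the two formulas the code comments cite (trace tau, square root of the determinant s):

V = \begin{pmatrix} V_{rr} & V_{ri} \\ V_{ri} & V_{ii} \end{pmatrix},
\qquad \tau = V_{rr} + V_{ii},
\qquad s = \sqrt{V_{rr} V_{ii} - V_{ri}^2},
\qquad t = \sqrt{\tau + 2s}

V^{-1/2} = \frac{1}{st} \begin{pmatrix} V_{ii} + s & -V_{ri} \\ -V_{ri} & V_{rr} + s \end{pmatrix}
         = \begin{pmatrix} U_{rr} & U_{ri} \\ U_{ri} & U_{ii} \end{pmatrix}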
@@ -0,0 +1,338 @@
+import logging
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from enhancer.data import EnhancerDataset
+from enhancer.models import Model
+from enhancer.models.complexnn import (
+    ComplexBatchNorm2D,
+    ComplexConv2d,
+    ComplexConvTranspose2d,
+    ComplexLSTM,
+    ComplexRelu,
+)
+from enhancer.models.complexnn.utils import complex_cat
+from enhancer.utils.transforms import ConviSTFT, ConvSTFT
+from enhancer.utils.utils import merge_dict
+
+
+class DCCRN_ENCODER(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channel: int,
+        kernel_size: Tuple[int, int],
+        complex_norm: bool = True,
+        complex_relu: bool = True,
+        stride: Tuple[int, int] = (2, 1),
+        padding: Tuple[int, int] = (2, 1),
+    ):
+        super().__init__()
+        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
+        activation = ComplexRelu() if complex_relu else nn.PReLU()
+
+        self.encoder = nn.Sequential(
+            ComplexConv2d(
+                in_channels,
+                out_channel,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+            ),
+            batchnorm(out_channel),
+            activation,
+        )
+
+    def forward(self, waveform):
+        return self.encoder(waveform)
+
+
+class DCCRN_DECODER(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int],
+        layer: int = 0,
+        complex_norm: bool = True,
+        complex_relu: bool = True,
+        stride: Tuple[int, int] = (2, 1),
+        padding: Tuple[int, int] = (2, 0),
+        output_padding: Tuple[int, int] = (1, 0),
+    ):
+        super().__init__()
+        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
+        activation = ComplexRelu() if complex_relu else nn.PReLU()
+
+        # the last decoder block (layer == 0) has no norm or activation,
+        # since it produces the output mask
+        if layer != 0:
+            self.decoder = nn.Sequential(
+                ComplexConvTranspose2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                ),
+                batchnorm(out_channels),
+                activation,
+            )
+        else:
+            self.decoder = nn.Sequential(
+                ComplexConvTranspose2d(
+                    in_channels,
+                    out_channels,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                    output_padding=output_padding,
+                )
+            )
+
+    def forward(self, waveform):
+        return self.decoder(waveform)
+
+
+class DCCRN(Model):
+
+    STFT_DEFAULTS = {
+        "window_len": 400,
+        "hop_size": 100,
+        "nfft": 512,
+        "window": "hamming",
+    }
+
+    ED_DEFAULTS = {
+        "initial_output_channels": 32,
+        "depth": 6,
+        "kernel_size": 5,
+        "growth_factor": 2,
+        "stride": 2,
+        "padding": 2,
+        "output_padding": 1,
+    }
+
+    LSTM_DEFAULTS = {
+        "num_layers": 2,
+        "hidden_size": 256,
+    }
+
+    def __init__(
+        self,
+        stft: Optional[dict] = None,
+        encoder_decoder: Optional[dict] = None,
+        lstm: Optional[dict] = None,
+        complex_lstm: bool = True,
+        complex_norm: bool = True,
+        complex_relu: bool = True,
+        masking_mode: str = "E",
+        num_channels: int = 1,
+        sampling_rate=16000,
+        lr: float = 1e-3,
+        dataset: Optional[EnhancerDataset] = None,
+        duration: Optional[float] = None,
+        loss: Union[str, List, Any] = "mse",
+        metric: Union[str, List] = "mse",
+    ):
+        duration = (
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
+        )
+        if dataset is not None:
+            if sampling_rate != dataset.sampling_rate:
+                logging.warning(
+                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+                )
+                sampling_rate = dataset.sampling_rate
+        super().__init__(
+            num_channels=num_channels,
+            sampling_rate=sampling_rate,
+            lr=lr,
+            dataset=dataset,
+            duration=duration,
+            loss=loss,
+            metric=metric,
+        )
+
+        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
+        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+        stft = merge_dict(self.STFT_DEFAULTS, stft)
+        self.save_hyperparameters(
+            "encoder_decoder",
+            "lstm",
+            "stft",
+            "complex_lstm",
+            "complex_norm",
+            "masking_mode",
+        )
+        self.complex_lstm = complex_lstm
+        self.complex_norm = complex_norm
+        self.masking_mode = masking_mode
+
+        self.stft = ConvSTFT(
+            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
+        )
+        self.istft = ConviSTFT(
+            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
+        )
+
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+
+        # channels are doubled to hold real and imaginary parts
+        num_channels *= 2
+        hidden_size = encoder_decoder["initial_output_channels"]
+        growth_factor = 2
+
+        for layer in range(encoder_decoder["depth"]):
+            encoder_ = DCCRN_ENCODER(
+                num_channels,
+                hidden_size,
+                kernel_size=(encoder_decoder["kernel_size"], 2),
+                stride=(encoder_decoder["stride"], 1),
+                padding=(encoder_decoder["padding"], 1),
+                complex_norm=complex_norm,
+                complex_relu=complex_relu,
+            )
+            self.encoder.append(encoder_)
+
+            # decoder input channels are doubled by the skip connection
+            decoder_ = DCCRN_DECODER(
+                hidden_size + hidden_size,
+                num_channels,
+                layer=layer,
+                kernel_size=(encoder_decoder["kernel_size"], 2),
+                stride=(encoder_decoder["stride"], 1),
+                padding=(encoder_decoder["padding"], 0),
+                output_padding=(encoder_decoder["output_padding"], 0),
+                complex_norm=complex_norm,
+                complex_relu=complex_relu,
+            )
+            self.decoder.insert(0, decoder_)
+
+            if layer < encoder_decoder["depth"] - 3:
+                num_channels = hidden_size
+                hidden_size *= growth_factor
+            else:
+                num_channels = hidden_size
+
+        kernel_size = hidden_size / 2
+        hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"])
+
+        if self.complex_lstm:
+            lstms = []
+            for layer in range(lstm["num_layers"]):
+                if layer == 0:
+                    input_size = int(hidden_size * kernel_size)
+                else:
+                    input_size = lstm["hidden_size"]
+
+                if layer == lstm["num_layers"] - 1:
+                    projection_size = int(hidden_size * kernel_size)
+                else:
+                    projection_size = None
+
+                kwargs = {
+                    "input_size": input_size,
+                    "hidden_size": lstm["hidden_size"],
+                    "num_layers": 1,
+                }
+
+                lstms.append(
+                    ComplexLSTM(projection_size=projection_size, **kwargs)
+                )
+            self.lstm = nn.Sequential(*lstms)
+        else:
+            # note: nn.LSTM returns (output, (h, c)); the output tuple would
+            # still need to be unwrapped before the Linear projection
+            self.lstm = nn.Sequential(
+                nn.LSTM(
+                    input_size=int(hidden_size * kernel_size),
+                    hidden_size=lstm["hidden_size"],
+                    num_layers=lstm["num_layers"],
+                    dropout=0.0,
+                    batch_first=False,
+                ),
+                nn.Linear(lstm["hidden_size"], int(hidden_size * kernel_size)),
+            )
+
+    def forward(self, waveform):
+        if waveform.dim() == 2:
+            waveform = waveform.unsqueeze(1)
+
+        if waveform.size(1) != self.hparams.num_channels:
+            raise ValueError(
+                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
+            )
+
+        waveform_stft = self.stft(waveform)
+        real = waveform_stft[:, : self.stft.nfft // 2 + 1]
+        imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]
+
+        mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
+        phase_spec = torch.atan2(imag, real)
+        complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]
+
+        encoder_outputs = []
+        out = complex_spec
+        for _, encoder in enumerate(self.encoder):
+            out = encoder(out)
+            encoder_outputs.append(out)
+
+        B, C, D, T = out.size()
+        out = out.permute(3, 0, 1, 2)
+        if self.complex_lstm:
+            lstm_real = out[:, :, : C // 2]
+            lstm_imag = out[:, :, C // 2 :]
+            lstm_real = lstm_real.reshape(T, B, C // 2 * D)
+            lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
+            lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
+            lstm_real = lstm_real.reshape(T, B, C // 2, D)
+            lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
+            out = torch.cat([lstm_real, lstm_imag], 2)
+        else:
+            out = out.reshape(T, B, C * D)
+            out = self.lstm(out)
+            out = out.reshape(T, B, D, C)
+
+        out = out.permute(1, 2, 3, 0)
+        for layer, decoder in enumerate(self.decoder):
+            skip_connection = encoder_outputs.pop(-1)
+            out = complex_cat([skip_connection, out])
+            out = decoder(out)
+            out = out[..., 1:]
+        mask_real, mask_imag = out[:, 0], out[:, 1]
+        mask_real = F.pad(mask_real, [0, 0, 1, 0])
+        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
+        if self.masking_mode == "E":
+            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
+            real_phase = mask_real / (mask_mag + 1e-8)
+            imag_phase = mask_imag / (mask_mag + 1e-8)
+            mask_phase = torch.atan2(imag_phase, real_phase)
+            mask_mag = torch.tanh(mask_mag)
+            est_mag = mask_mag * mag_spec
+            est_phase = mask_phase * phase_spec
+            # polar to rectangular: mag * (cos(theta) + i sin(theta))
+            real = est_mag * torch.cos(est_phase)
+            imag = est_mag * torch.sin(est_phase)
+
+        elif self.masking_mode == "C":
+            # complex mask multiplication; temporaries keep `real` intact
+            # while the imaginary part is computed
+            masked_real = real * mask_real - imag * mask_imag
+            masked_imag = real * mask_imag + imag * mask_real
+            real, imag = masked_real, masked_imag
+
+        else:
+            real = real * mask_real
+            imag = imag * mask_imag
+
+        spec = torch.cat([real, imag], 1)
+        wav = self.istft(spec)
+        wav = wav.clamp_(-1, 1)
+        return wav
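A usage sketch for the new model, matching the forward test added later in this commit (defaults: 16 kHz, mono, masking mode "E"):

import torch

from enhancer.models.dccrn import DCCRN

model = DCCRN()
model.eval()
with torch.no_grad():
    enhanced = model(torch.rand(1, 1, 1000))  # enhanced waveform, clamped to [-1, 1]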
@@ -204,9 +204,9 @@ class Demucs(Model):
         if waveform.dim() == 2:
             waveform = waveform.unsqueeze(1)

-        if waveform.size(1) != 1:
-            raise TypeError(
-                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
+        if waveform.size(1) != self.hparams.num_channels:
+            raise ValueError(
+                f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels"
             )
         if self.normalize:
             waveform = waveform.mean(dim=1, keepdim=True)
@@ -2,7 +2,7 @@ import os
 from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import List, Optional, Text, Union
+from typing import Any, List, Optional, Text, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
 import torch
 from huggingface_hub import cached_download, hf_hub_url
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from torch import nn
 from torch.optim import Adam

 from enhancer.data.dataset import EnhancerDataset
@@ -36,7 +37,7 @@ class Model(pl.LightningModule):
         Enhancer dataset used for training/validation
     duration: float, optional
         duration used for training/inference
-    loss : string or List of strings, default to "mse"
+    loss : string or List of strings or custom loss (nn.Module), default to "mse"
         loss functions to be used. Available ("mse","mae","Si-SDR")

     """
@@ -49,7 +50,7 @@ class Model(pl.LightningModule):
         dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric: Union[str, List] = "mse",
+        metric: Union[str, List, Any] = "mse",
     ):
         super().__init__()
         assert (
@@ -86,10 +87,11 @@ class Model(pl.LightningModule):
     @metric.setter
     def metric(self, metric):
         self._metric = []
-        if isinstance(metric, str):
+        if isinstance(metric, (str, nn.Module)):
             metric = [metric]

         for func in metric:
-            if func in LOSS_MAP.keys():
-                if func in ("pesq", "stoi"):
-                    self._metric.append(
+            if isinstance(func, str):
+                if func in LOSS_MAP.keys():
+                    if func in ("pesq", "stoi"):
+                        self._metric.append(
@@ -97,9 +99,13 @@ class Model(pl.LightningModule):
-                    )
-                else:
-                    self._metric.append(LOSS_MAP[func]())
-
-            else:
-                raise ValueError(f"Invalid metrics {func}")
+                        )
+                    else:
+                        self._metric.append(LOSS_MAP[func]())
+
+                else:
+                    raise ValueError(f"Invalid metrics {func}")
+            elif isinstance(func, nn.Module):
+                self._metric.append(func)
+            else:
+                raise ValueError("Invalid metrics")

     @property
     def dataset(self):
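With this change a custom nn.Module can be passed as a metric alongside the built-in strings. A hypothetical sketch (SNRMetric is illustrative only, not part of the commit):

import torch
from torch import nn


class SNRMetric(nn.Module):
    """Hypothetical negative-SNR metric, for illustration only."""

    def forward(self, prediction, target):
        noise = target - prediction
        return -10 * torch.log10(
            target.pow(2).mean() / (noise.pow(2).mean() + 1e-8)
        )


# metrics may now mix strings and modules, e.g. metric=["mse", SNRMetric()]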
@@ -0,0 +1,92 @@
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.signal import get_window
+from torch import nn
+
+
+class ConvFFT(nn.Module):
+    def __init__(
+        self,
+        window_len: int,
+        nfft: Optional[int] = None,
+        window: str = "hamming",
+    ):
+        super().__init__()
+        self.window_len = window_len
+        # default nfft: next power of two >= window_len
+        self.nfft = nfft if nfft else int(2 ** np.ceil(np.log2(window_len)))
+        self.window = torch.from_numpy(
+            get_window(window, window_len, fftbins=True).astype("float32")
+        )
+
+    def init_kernel(self, inverse=False):
+        # rows of the DFT matrix, truncated to the window length;
+        # real parts of all bins are stacked before the imaginary parts
+        fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len]
+        real, imag = np.real(fourier_basis), np.imag(fourier_basis)
+        kernel = np.concatenate([real, imag], 1).T
+        if inverse:
+            kernel = np.linalg.pinv(kernel).T
+        kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1)
+        kernel *= self.window
+        return kernel
+
+
+class ConvSTFT(ConvFFT):
+    def __init__(
+        self,
+        window_len: int,
+        hop_size: Optional[int] = None,
+        nfft: Optional[int] = None,
+        window: str = "hamming",
+    ):
+        super().__init__(window_len=window_len, nfft=nfft, window=window)
+        self.hop_size = hop_size if hop_size else window_len // 2
+        self.register_buffer("weight", self.init_kernel())
+
+    def forward(self, input):
+        if input.dim() < 2:
+            raise ValueError(
+                f"Expected a 2D or 3D signal, got {input.dim()}D"
+            )
+        elif input.dim() == 2:
+            input = input.unsqueeze(1)
+        input = F.pad(
+            input,
+            (self.window_len - self.hop_size, self.window_len - self.hop_size),
+        )
+        output = F.conv1d(input, self.weight, stride=self.hop_size)
+
+        return output
+
+
+class ConviSTFT(ConvFFT):
+    def __init__(
+        self,
+        window_len: int,
+        hop_size: Optional[int] = None,
+        nfft: Optional[int] = None,
+        window: str = "hamming",
+    ):
+        super().__init__(window_len=window_len, nfft=nfft, window=window)
+        self.hop_size = hop_size if hop_size else window_len // 2
+        self.register_buffer("weight", self.init_kernel(True))
+        self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1))
+
+    def forward(self, input, phase=None):
+        if phase is not None:
+            real = input * torch.cos(phase)
+            imag = input * torch.sin(phase)
+            input = torch.cat([real, imag], 1)
+        out = F.conv_transpose1d(input, self.weight, stride=self.hop_size)
+        # normalize by the overlap-added squared window
+        coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2
+        coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size)
+        out = out / (coeff + 1e-8)
+        pad = self.window_len - self.hop_size
+        out = out[..., pad:-pad]
+        return out
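ConvSTFT realizes the short-time Fourier transform as a strided 1-D convolution: each output channel correlates the padded signal with one windowed Fourier-basis row. For window w of length L, hop size h and FFT size N:

X[k, t] = \sum_{n=0}^{L-1} w[n] \, x[t h + n] \, e^{-2\pi i k n / N}

ConviSTFT inverts this with a transposed convolution using the pseudo-inverse kernel, then divides out the overlap-added squared window.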
@@ -0,0 +1,50 @@
+import torch
+
+from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d
+from enhancer.models.complexnn.rnn import ComplexLSTM
+from enhancer.models.complexnn.utils import ComplexBatchNorm2D
+
+
+def test_complexconv2d():
+    sample_input = torch.rand(1, 2, 256, 13)
+    conv = ComplexConv2d(
+        2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)
+    )
+    with torch.no_grad():
+        out = conv(sample_input)
+    assert out.shape == torch.Size([1, 32, 128, 13])
+
+
+def test_complexconvtranspose2d():
+    sample_input = torch.rand(1, 512, 4, 13)
+    conv = ComplexConvTranspose2d(
+        256 * 2,
+        128 * 2,
+        kernel_size=(5, 2),
+        stride=(2, 1),
+        padding=(2, 0),
+        output_padding=(1, 0),
+    )
+    with torch.no_grad():
+        out = conv(sample_input)
+
+    assert out.shape == torch.Size([1, 256, 8, 14])
+
+
+def test_complexlstm():
+    sample_input = torch.rand(13, 2, 128)
+    lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2)
+    with torch.no_grad():
+        out = lstm(sample_input)
+
+    assert out[0].shape == torch.Size([13, 1, 512])
+    assert out[1].shape == torch.Size([13, 1, 512])
+
+
+def test_complexbatchnorm2d():
+    sample_input = torch.rand(1, 64, 64, 14)
+    batchnorm = ComplexBatchNorm2D(num_features=64)
+    with torch.no_grad():
+        out = batchnorm(sample_input)
+
+    assert out.size() == sample_input.size()
@@ -30,7 +30,7 @@ def test_forward(batch_size, samples):

     data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             _ = model(data)

@@ -0,0 +1,43 @@
+import pytest
+import torch
+
+from enhancer.data.dataset import EnhancerDataset
+from enhancer.models.dccrn import DCCRN
+from enhancer.utils.config import Files
+
+
+@pytest.fixture
+def vctk_dataset():
+    root_dir = "tests/data/vctk"
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
+    return dataset
+
+
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
+    model = DCCRN()
+    model.eval()
+
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
+    with torch.no_grad():
+        _ = model(data)
+
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
+    with torch.no_grad():
+        with pytest.raises(ValueError):
+            _ = model(data)
+
+
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_dccrn_init(dataset, channels, loss):
+    with torch.no_grad():
+        _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)
@@ -0,0 +1,18 @@
+import torch
+
+from enhancer.utils.transforms import ConviSTFT, ConvSTFT
+
+
+def test_stft_istft():
+    sample_input = torch.rand(1, 1, 16000)
+    stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
+    istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
+
+    with torch.no_grad():
+        spectrogram = stft(sample_input)
+        waveform = istft(spectrogram)
+    assert sample_input.shape == waveform.shape
+    assert (
+        torch.isclose(waveform, sample_input).sum().item()
+        > sample_input.shape[-1] // 2
+    )