diff --git a/enhancer/cli/train_config/model/DCCRN.yaml b/enhancer/cli/train_config/model/DCCRN.yaml new file mode 100644 index 0000000..3190391 --- /dev/null +++ b/enhancer/cli/train_config/model/DCCRN.yaml @@ -0,0 +1,25 @@ +_target_: enhancer.models.dccrn.DCCRN +num_channels: 1 +sampling_rate : 16000 +complex_lstm : True +complex_norm : True +complex_relu : True +masking_mode : True + +encoder_decoder: + initial_output_channels : 32 + depth : 6 + kernel_size : 5 + growth_factor : 2 + stride : 2 + padding : 2 + output_padding : 1 + +lstm: + num_layers : 2 + hidden_size : 256 + +stft: + window_len : 400 + hop_size : 100 + nfft : 512 diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 2444c7f..1ef27f7 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule): name: str, root_dir: str, files: Files, - valid_minutes: float = 0.20, + min_valid_minutes: float = 0.20, duration: float = 1.0, stride=None, sampling_rate: int = 48000, @@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule): if num_workers is None: num_workers = multiprocessing.cpu_count() // 2 self.num_workers = num_workers - if valid_minutes > 0.0: - self.valid_minutes = valid_minutes + if min_valid_minutes > 0.0: + self.min_valid_minutes = min_valid_minutes else: - raise ValueError("valid_minutes must be greater than 0") + raise ValueError("min_valid_minutes must be greater than 0") self.augmentations = augmentations @@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule): ) train_data = fp.prepare_matching_dict() train_data, self.val_data = self.train_valid_split( - train_data, valid_minutes=self.valid_minutes, random_state=42 + train_data, + min_valid_minutes=self.min_valid_minutes, + random_state=42, ) self.train_data = self.prepare_traindata(train_data) @@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule): self._test = self.prepare_mapstype(test_data) def train_valid_split( - self, data, valid_minutes: float = 20, random_state: int = 42 + self, data, min_valid_minutes: float = 20, random_state: int = 42 ): - valid_minutes *= 60 + min_valid_minutes *= 60 valid_sec_now = 0.0 valid_indices = [] all_speakers = np.unique( @@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule): possible_indices = list(range(0, len(all_speakers))) rng = create_unique_rng(len(all_speakers)) - while valid_sec_now <= valid_minutes: + while valid_sec_now <= min_valid_minutes: speaker_index = rng.choice(possible_indices) possible_indices.remove(speaker_index) speaker_name = all_speakers[speaker_index] @@ -257,6 +259,9 @@ class EnhancerDataset(TaskDataset): files : Files dataclass containing train_clean, train_noisy, test_clean, test_noisy folder names (refer enhancer.utils.Files dataclass) + min_valid_minutes: float + minimum validation split size time in minutes + algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data. 
duration : float expected audio duration of single audio sample for training sampling_rate : int @@ -271,6 +276,7 @@ class EnhancerDataset(TaskDataset): use one_to_many mapping for multiple noisy files for each clean file + """ def __init__( @@ -278,7 +284,7 @@ class EnhancerDataset(TaskDataset): name: str, root_dir: str, files: Files, - valid_minutes=5.0, + min_valid_minutes=5.0, duration=1.0, stride=None, sampling_rate=48000, @@ -292,7 +298,7 @@ class EnhancerDataset(TaskDataset): name=name, root_dir=root_dir, files=files, - valid_minutes=valid_minutes, + min_valid_minutes=min_valid_minutes, sampling_rate=sampling_rate, duration=duration, matching_function=matching_function, diff --git a/enhancer/models/complexnn/__init__.py b/enhancer/models/complexnn/__init__.py new file mode 100644 index 0000000..918a261 --- /dev/null +++ b/enhancer/models/complexnn/__init__.py @@ -0,0 +1,5 @@ +from enhancer.models.complexnn.conv import ComplexConv2d # noqa +from enhancer.models.complexnn.conv import ComplexConvTranspose2d # noqa +from enhancer.models.complexnn.rnn import ComplexLSTM # noqa +from enhancer.models.complexnn.utils import ComplexBatchNorm2D # noqa +from enhancer.models.complexnn.utils import ComplexRelu # noqa diff --git a/enhancer/models/complexnn/conv.py b/enhancer/models/complexnn/conv.py new file mode 100644 index 0000000..d9a4d0f --- /dev/null +++ b/enhancer/models/complexnn/conv.py @@ -0,0 +1,136 @@ +from typing import Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +def init_weights(nnet): + nn.init.xavier_normal_(nnet.weight.data) + nn.init.constant_(nnet.bias, 0.0) + return nnet + + +class ComplexConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + groups: int = 1, + dilation: int = 1, + ): + """ + Complex Conv2d (non-causal) + """ + super().__init__() + self.in_channels = in_channels // 2 + self.out_channels = out_channels // 2 + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.groups = groups + self.dilation = dilation + + self.real_conv = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=(self.padding[0], 0), + groups=self.groups, + dilation=self.dilation, + ) + self.imag_conv = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=(self.padding[0], 0), + groups=self.groups, + dilation=self.dilation, + ) + self.imag_conv = init_weights(self.imag_conv) + self.real_conv = init_weights(self.real_conv) + + def forward(self, input): + """ + complex axis should be always 1 dim + """ + input = F.pad(input, [self.padding[1], 0, 0, 0]) + + real, imag = torch.chunk(input, 2, 1) + + real_real = self.real_conv(real) + real_imag = self.imag_conv(real) + + imag_imag = self.imag_conv(imag) + imag_real = self.real_conv(imag) + + real = real_real - imag_imag + imag = real_imag - imag_real + + out = torch.cat([real, imag], 1) + return out + + +class ComplexConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int] = (1, 1), + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + output_padding: Tuple[int, int] = (0, 0), + groups: int = 1, + ): + super().__init__() + self.in_channels = in_channels // 2 + self.out_channels = out_channels // 2 + self.kernel_size = kernel_size 
+ self.stride = stride + self.padding = padding + self.groups = groups + self.output_padding = output_padding + + self.real_conv = nn.ConvTranspose2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + groups=self.groups, + ) + + self.imag_conv = nn.ConvTranspose2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + output_padding=self.output_padding, + groups=self.groups, + ) + + self.real_conv = init_weights(self.real_conv) + self.imag_conv = init_weights(self.imag_conv) + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + real_real = self.real_conv(real) + real_imag = self.imag_conv(real) + + imag_imag = self.imag_conv(imag) + imag_real = self.real_conv(imag) + + real = real_real - imag_imag + imag = real_imag - imag_real + + out = torch.cat([real, imag], 1) + + return out diff --git a/enhancer/models/complexnn/rnn.py b/enhancer/models/complexnn/rnn.py new file mode 100644 index 0000000..847030b --- /dev/null +++ b/enhancer/models/complexnn/rnn.py @@ -0,0 +1,68 @@ +from typing import List, Optional + +import torch +from torch import nn + + +class ComplexLSTM(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int = 1, + projection_size: Optional[int] = None, + bidirectional: bool = False, + ): + super().__init__() + self.input_size = input_size // 2 + self.hidden_size = hidden_size // 2 + self.num_layers = num_layers + + self.real_lstm = nn.LSTM( + self.input_size, + self.hidden_size, + self.num_layers, + bidirectional=bidirectional, + batch_first=False, + ) + self.imag_lstm = nn.LSTM( + self.input_size, + self.hidden_size, + self.num_layers, + bidirectional=bidirectional, + batch_first=False, + ) + + bidirectional = 2 if bidirectional else 1 + if projection_size is not None: + self.projection_size = projection_size // 2 + self.real_linear = nn.Linear( + self.hidden_size * bidirectional, self.projection_size + ) + self.imag_linear = nn.Linear( + self.hidden_size * bidirectional, self.projection_size + ) + else: + self.projection_size = None + + def forward(self, input): + + if isinstance(input, List): + real, imag = input + else: + real, imag = torch.chunk(input, 2, 1) + + real_real = self.real_lstm(real)[0] + real_imag = self.imag_lstm(real)[0] + + imag_imag = self.imag_lstm(imag)[0] + imag_real = self.real_lstm(imag)[0] + + real = real_real - imag_imag + imag = imag_real + real_imag + + if self.projection_size is not None: + real = self.real_linear(real) + imag = self.imag_linear(imag) + + return [real, imag] diff --git a/enhancer/models/complexnn/utils.py b/enhancer/models/complexnn/utils.py new file mode 100644 index 0000000..0c28f9b --- /dev/null +++ b/enhancer/models/complexnn/utils.py @@ -0,0 +1,199 @@ +import torch +from torch import nn + + +class ComplexBatchNorm2D(nn.Module): + def __init__( + self, + num_features: int, + eps: float = 1e-5, + momentum: float = 0.1, + affine: bool = True, + track_running_stats: bool = True, + ): + """ + Complex batch normalization 2D + https://arxiv.org/abs/1705.09792 + + + """ + super().__init__() + self.num_features = num_features // 2 + self.affine = affine + self.momentum = momentum + self.track_running_stats = track_running_stats + self.eps = eps + + if self.affine: + self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features)) + 
self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features)) + else: + self.register_parameter("Wrr", None) + self.register_parameter("Wri", None) + self.register_parameter("Wii", None) + self.register_parameter("Br", None) + self.register_parameter("Bi", None) + + if self.track_running_stats: + values = torch.zeros(self.num_features) + self.register_buffer("Mean_real", values) + self.register_buffer("Mean_imag", values) + self.register_buffer("Var_rr", values) + self.register_buffer("Var_ri", values) + self.register_buffer("Var_ii", values) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("Mean_real", None) + self.register_parameter("Mean_imag", None) + self.register_parameter("Var_rr", None) + self.register_parameter("Var_ri", None) + self.register_parameter("Var_ii", None) + self.register_parameter("num_batches_tracked", None) + + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + self.Wrr.data.fill_(1) + self.Wii.data.fill_(1) + self.Wri.data.uniform_(-0.9, 0.9) + self.Br.data.fill_(0) + self.Bi.data.fill_(0) + self.reset_running_stats() + + def reset_running_stats(self): + if self.track_running_stats: + self.Mean_real.zero_() + self.Mean_imag.zero_() + self.Var_rr.fill_(1) + self.Var_ri.zero_() + self.Var_ii.fill_(1) + self.num_batches_tracked.zero_() + + def extra_repr(self): + return "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, track_running_stats={track_running_stats}".format( + **self.__dict__ + ) + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + exp_avg_factor = 0.0 + + training = self.training and self.track_running_stats + if training: + self.num_batches_tracked += 1 + if self.momentum is None: + exp_avg_factor = 1 / self.num_batches_tracked + else: + exp_avg_factor = self.momentum + + redux = [i for i in reversed(range(real.dim())) if i != 1] + vdim = [1] * real.dim() + vdim[1] = real.size(1) + + if training: + batch_mean_real, batch_mean_imag = real, imag + for dim in redux: + batch_mean_real = batch_mean_real.mean(dim, keepdim=True) + batch_mean_imag = batch_mean_imag.mean(dim, keepdim=True) + if self.track_running_stats: + self.Mean_real.lerp_(batch_mean_real.squeeze(), exp_avg_factor) + self.Mean_imag.lerp_(batch_mean_imag.squeeze(), exp_avg_factor) + + else: + batch_mean_real = self.Mean_real.view(vdim) + batch_mean_imag = self.Mean_imag.view(vdim) + + real = real - batch_mean_real + imag = imag - batch_mean_imag + + if training: + batch_var_rr = real * real + batch_var_ri = real * imag + batch_var_ii = imag * imag + for dim in redux: + batch_var_rr = batch_var_rr.mean(dim, keepdim=True) + batch_var_ri = batch_var_ri.mean(dim, keepdim=True) + batch_var_ii = batch_var_ii.mean(dim, keepdim=True) + if self.track_running_stats: + self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor) + self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor) + self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor) + else: + batch_var_rr = self.Var_rr.view(vdim) + batch_var_ii = self.Var_ii.view(vdim) + batch_var_ri = self.Var_ri.view(vdim) + + batch_var_rr += self.eps + batch_var_ii += self.eps + + # Covariance matrics + # | batch_var_rr batch_var_ri | + # | batch_var_ir batch_var_ii | here batch_var_ir == batch_var_ri + # Inverse square root of cov matrix by combining below two formulas + # 
https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix + # https://mathworld.wolfram.com/MatrixInverse.html + + tau = batch_var_rr + batch_var_ii + s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri + t = (tau + 2 * s).sqrt() + + rst = (s * t).reciprocal() + Urr = (batch_var_ii + s) * rst + Uri = -batch_var_ri * rst + Uii = (batch_var_rr + s) * rst + + if self.affine: + Wrr, Wri, Wii = ( + self.Wrr.view(vdim), + self.Wri.view(vdim), + self.Wii.view(vdim), + ) + Zrr = (Wrr * Urr) + (Wri * Uri) + Zri = (Wrr * Uri) + (Wri * Uii) + Zir = (Wii * Uri) + (Wri * Urr) + Zii = (Wri * Uri) + (Wii * Uii) + else: + Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii + + yr = (Zrr * real) + (Zri * imag) + yi = (Zir * real) + (Zii * imag) + + if self.affine: + yr = yr + self.Br.view(vdim) + yi = yi + self.Bi.view(vdim) + + outputs = torch.cat([yr, yi], 1) + return outputs + + +class ComplexRelu(nn.Module): + def __init__(self): + super().__init__() + self.real_relu = nn.PReLU() + self.imag_relu = nn.PReLU() + + def forward(self, input): + + real, imag = torch.chunk(input, 2, 1) + real = self.real_relu(real) + imag = self.imag_relu(imag) + return torch.cat([real, imag], dim=1) + + +def complex_cat(inputs, axis=1): + + real, imag = [], [] + for data in inputs: + real_data, imag_data = torch.chunk(data, 2, axis) + real.append(real_data) + imag.append(imag_data) + real = torch.cat(real, axis) + imag = torch.cat(imag, axis) + return torch.cat([real, imag], axis) diff --git a/enhancer/models/dccrn.py b/enhancer/models/dccrn.py new file mode 100644 index 0000000..7b1e5b1 --- /dev/null +++ b/enhancer/models/dccrn.py @@ -0,0 +1,338 @@ +import logging +from typing import Any, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from enhancer.data import EnhancerDataset +from enhancer.models import Model +from enhancer.models.complexnn import ( + ComplexBatchNorm2D, + ComplexConv2d, + ComplexConvTranspose2d, + ComplexLSTM, + ComplexRelu, +) +from enhancer.models.complexnn.utils import complex_cat +from enhancer.utils.transforms import ConviSTFT, ConvSTFT +from enhancer.utils.utils import merge_dict + + +class DCCRN_ENCODER(nn.Module): + def __init__( + self, + in_channels: int, + out_channel: int, + kernel_size: Tuple[int, int], + complex_norm: bool = True, + complex_relu: bool = True, + stride: Tuple[int, int] = (2, 1), + padding: Tuple[int, int] = (2, 1), + ): + super().__init__() + batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d + activation = ComplexRelu() if complex_relu else nn.PReLU() + + self.encoder = nn.Sequential( + ComplexConv2d( + in_channels, + out_channel, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + batchnorm(out_channel), + activation, + ) + + def forward(self, waveform): + + return self.encoder(waveform) + + +class DCCRN_DECODER(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Tuple[int, int], + layer: int = 0, + complex_norm: bool = True, + complex_relu: bool = True, + stride: Tuple[int, int] = (2, 1), + padding: Tuple[int, int] = (2, 0), + output_padding: Tuple[int, int] = (1, 0), + ): + super().__init__() + batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d + activation = ComplexRelu() if complex_relu else nn.PReLU() + + if layer != 0: + self.decoder = nn.Sequential( + ComplexConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ), + 
batchnorm(out_channels), + activation, + ) + else: + self.decoder = nn.Sequential( + ComplexConvTranspose2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + ) + + def forward(self, waveform): + + return self.decoder(waveform) + + +class DCCRN(Model): + + STFT_DEFAULTS = { + "window_len": 400, + "hop_size": 100, + "nfft": 512, + "window": "hamming", + } + + ED_DEFAULTS = { + "initial_output_channels": 32, + "depth": 6, + "kernel_size": 5, + "growth_factor": 2, + "stride": 2, + "padding": 2, + "output_padding": 1, + } + + LSTM_DEFAULTS = { + "num_layers": 2, + "hidden_size": 256, + } + + def __init__( + self, + stft: Optional[dict] = None, + encoder_decoder: Optional[dict] = None, + lstm: Optional[dict] = None, + complex_lstm: bool = True, + complex_norm: bool = True, + complex_relu: bool = True, + masking_mode: str = "E", + num_channels: int = 1, + sampling_rate=16000, + lr: float = 1e-3, + dataset: Optional[EnhancerDataset] = None, + duration: Optional[float] = None, + loss: Union[str, List, Any] = "mse", + metric: Union[str, List] = "mse", + ): + duration = ( + dataset.duration if isinstance(dataset, EnhancerDataset) else None + ) + if dataset is not None: + if sampling_rate != dataset.sampling_rate: + logging.warning( + f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" + ) + sampling_rate = dataset.sampling_rate + super().__init__( + num_channels=num_channels, + sampling_rate=sampling_rate, + lr=lr, + dataset=dataset, + duration=duration, + loss=loss, + metric=metric, + ) + + encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder) + lstm = merge_dict(self.LSTM_DEFAULTS, lstm) + stft = merge_dict(self.STFT_DEFAULTS, stft) + self.save_hyperparameters( + "encoder_decoder", + "lstm", + "stft", + "complex_lstm", + "complex_norm", + "masking_mode", + ) + self.complex_lstm = complex_lstm + self.complex_norm = complex_norm + self.masking_mode = masking_mode + + self.stft = ConvSTFT( + stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] + ) + self.istft = ConviSTFT( + stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"] + ) + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + num_channels *= 2 + hidden_size = encoder_decoder["initial_output_channels"] + growth_factor = 2 + + for layer in range(encoder_decoder["depth"]): + + encoder_ = DCCRN_ENCODER( + num_channels, + hidden_size, + kernel_size=(encoder_decoder["kernel_size"], 2), + stride=(encoder_decoder["stride"], 1), + padding=(encoder_decoder["padding"], 1), + complex_norm=complex_norm, + complex_relu=complex_relu, + ) + self.encoder.append(encoder_) + + decoder_ = DCCRN_DECODER( + hidden_size + hidden_size, + num_channels, + layer=layer, + kernel_size=(encoder_decoder["kernel_size"], 2), + stride=(encoder_decoder["stride"], 1), + padding=(encoder_decoder["padding"], 0), + output_padding=(encoder_decoder["output_padding"], 0), + complex_norm=complex_norm, + complex_relu=complex_relu, + ) + + self.decoder.insert(0, decoder_) + + if layer < encoder_decoder["depth"] - 3: + num_channels = hidden_size + hidden_size *= growth_factor + else: + num_channels = hidden_size + + kernel_size = hidden_size / 2 + hidden_size = stft["nfft"] / 2 ** (encoder_decoder["depth"]) + + if self.complex_lstm: + lstms = [] + for layer in range(lstm["num_layers"]): + + if layer == 0: + input_size = int(hidden_size * kernel_size) + else: + input_size = lstm["hidden_size"] + + if layer 
== lstm["num_layers"] - 1: + projection_size = int(hidden_size * kernel_size) + else: + projection_size = None + + kwargs = { + "input_size": input_size, + "hidden_size": lstm["hidden_size"], + "num_layers": 1, + } + + lstms.append( + ComplexLSTM(projection_size=projection_size, **kwargs) + ) + self.lstm = nn.Sequential(*lstms) + else: + self.lstm = nn.Sequential( + nn.LSTM( + input_size=hidden_size * kernel_size, + hidden_sizs=lstm["hidden_size"], + num_layers=lstm["num_layers"], + dropout=0.0, + batch_first=False, + )[0], + nn.Linear(lstm["hidden"], hidden_size * kernel_size), + ) + + def forward(self, waveform): + + if waveform.dim() == 2: + waveform = waveform.unsqueeze(1) + + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" + ) + + waveform_stft = self.stft(waveform) + real = waveform_stft[:, : self.stft.nfft // 2 + 1] + imag = waveform_stft[:, self.stft.nfft // 2 + 1 :] + + mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9) + phase_spec = torch.atan2(imag, real) + complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:] + + encoder_outputs = [] + out = complex_spec + for _, encoder in enumerate(self.encoder): + out = encoder(out) + encoder_outputs.append(out) + + B, C, D, T = out.size() + out = out.permute(3, 0, 1, 2) + if self.complex_lstm: + + lstm_real = out[:, :, : C // 2] + lstm_imag = out[:, :, C // 2 :] + lstm_real = lstm_real.reshape(T, B, C // 2 * D) + lstm_imag = lstm_imag.reshape(T, B, C // 2 * D) + lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag]) + lstm_real = lstm_real.reshape(T, B, C // 2, D) + lstm_imag = lstm_imag.reshape(T, B, C // 2, D) + out = torch.cat([lstm_real, lstm_imag], 2) + else: + out = out.reshape(T, B, C * D) + out = self.lstm(out) + out = out.reshape(T, B, D, C) + + out = out.permute(1, 2, 3, 0) + for layer, decoder in enumerate(self.decoder): + skip_connection = encoder_outputs.pop(-1) + out = complex_cat([skip_connection, out]) + out = decoder(out) + out = out[..., 1:] + mask_real, mask_imag = out[:, 0], out[:, 1] + mask_real = F.pad(mask_real, [0, 0, 1, 0]) + mask_imag = F.pad(mask_imag, [0, 0, 1, 0]) + if self.masking_mode == "E": + + mask_mag = torch.sqrt(mask_real**2 + mask_imag**2) + real_phase = mask_real / (mask_mag + 1e-8) + imag_phase = mask_imag / (mask_mag + 1e-8) + mask_phase = torch.atan2(imag_phase, real_phase) + mask_mag = torch.tanh(mask_mag) + est_mag = mask_mag * mag_spec + est_phase = mask_phase * phase_spec + # cos(theta) + isin(theta) + real = est_mag + torch.cos(est_phase) + imag = est_mag + torch.sin(est_phase) + + if self.masking_mode == "C": + + real = real * mask_real - imag * mask_imag + imag = real * mask_imag + imag * mask_real + + else: + + real = real * mask_real + imag = imag * mask_imag + + spec = torch.cat([real, imag], 1) + wav = self.istft(spec) + wav = wav.clamp_(-1, 1) + return wav diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index e5fa945..fafb84e 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -204,9 +204,9 @@ class Demucs(Model): if waveform.dim() == 2: waveform = waveform.unsqueeze(1) - if waveform.size(1) != 1: - raise TypeError( - f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" + if waveform.size(1) != self.hparams.num_channels: + raise ValueError( + f"Number of input channels initialized is {self.hparams.num_channels} but got {waveform.size(1)} channels" ) if self.normalize: waveform 
= waveform.mean(dim=1, keepdim=True) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 3b60b85..c679669 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -2,7 +2,7 @@ import os from collections import defaultdict from importlib import import_module from pathlib import Path -from typing import List, Optional, Text, Union +from typing import Any, List, Optional, Text, Union from urllib.parse import urlparse import numpy as np @@ -10,6 +10,7 @@ import pytorch_lightning as pl import torch from huggingface_hub import cached_download, hf_hub_url from pytorch_lightning.utilities.cloud_io import load as pl_load +from torch import nn from torch.optim import Adam from enhancer.data.dataset import EnhancerDataset @@ -36,7 +37,7 @@ class Model(pl.LightningModule): Enhancer dataset used for training/validation duration: float, optional duration used for training/inference - loss : string or List of strings, default to "mse" + loss : string or List of strings or custom loss (nn.Module), default to "mse" loss functions to be used. Available ("mse","mae","Si-SDR") """ @@ -49,7 +50,7 @@ class Model(pl.LightningModule): dataset: Optional[EnhancerDataset] = None, duration: Optional[float] = None, loss: Union[str, List] = "mse", - metric: Union[str, List] = "mse", + metric: Union[str, List, Any] = "mse", ): super().__init__() assert ( @@ -86,20 +87,25 @@ class Model(pl.LightningModule): @metric.setter def metric(self, metric): self._metric = [] - if isinstance(metric, str): + if isinstance(metric, (str, nn.Module)): metric = [metric] for func in metric: - if func in LOSS_MAP.keys(): - if func in ("pesq", "stoi"): - self._metric.append( - LOSS_MAP[func](self.hparams.sampling_rate) - ) + if isinstance(func, str): + if func in LOSS_MAP.keys(): + if func in ("pesq", "stoi"): + self._metric.append( + LOSS_MAP[func](self.hparams.sampling_rate) + ) + else: + self._metric.append(LOSS_MAP[func]()) else: - self._metric.append(LOSS_MAP[func]()) + ValueError(f"Invalid metrics {func}") + elif isinstance(func, nn.Module): + self._metric.append(func) else: - raise ValueError(f"Invalid metrics {func}") + raise ValueError("Invalid metrics") @property def dataset(self): diff --git a/enhancer/utils/transforms.py b/enhancer/utils/transforms.py new file mode 100644 index 0000000..fbdb8f9 --- /dev/null +++ b/enhancer/utils/transforms.py @@ -0,0 +1,92 @@ +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.signal import get_window +from torch import nn + + +class ConvFFT(nn.Module): + def __init__( + self, + window_len: int, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super().__init__() + self.window_len = window_len + self.nfft = nfft if nfft else np.int(2 ** np.ceil(np.log2(window_len))) + self.window = torch.from_numpy( + get_window(window, window_len, fftbins=True).astype("float32") + ) + + def init_kernel(self, inverse=False): + + fourier_basis = np.fft.rfft(np.eye(self.nfft))[: self.window_len] + real, imag = np.real(fourier_basis), np.imag(fourier_basis) + kernel = np.concatenate([real, imag], 1).T + if inverse: + kernel = np.linalg.pinv(kernel).T + kernel = torch.from_numpy(kernel.astype("float32")).unsqueeze(1) + kernel *= self.window + return kernel + + +class ConvSTFT(ConvFFT): + def __init__( + self, + window_len: int, + hop_size: Optional[int] = None, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super().__init__(window_len=window_len, nfft=nfft, window=window) + self.hop_size = 
hop_size if hop_size else window_len // 2 + self.register_buffer("weight", self.init_kernel()) + + def forward(self, input): + + if input.dim() < 2: + raise ValueError( + f"Expected signal with shape 2 or 3 got {input.dim()}" + ) + elif input.dim() == 2: + input = input.unsqueeze(1) + else: + pass + input = F.pad( + input, + (self.window_len - self.hop_size, self.window_len - self.hop_size), + ) + output = F.conv1d(input, self.weight, stride=self.hop_size) + + return output + + +class ConviSTFT(ConvFFT): + def __init__( + self, + window_len: int, + hop_size: Optional[int] = None, + nfft: Optional[int] = None, + window: str = "hamming", + ): + super().__init__(window_len=window_len, nfft=nfft, window=window) + self.hop_size = hop_size if hop_size else window_len // 2 + self.register_buffer("weight", self.init_kernel(True)) + self.register_buffer("enframe", torch.eye(window_len).unsqueeze(1)) + + def forward(self, input, phase=None): + + if phase is not None: + real = input * torch.cos(phase) + imag = input * torch.sin(phase) + input = torch.cat([real, imag], 1) + out = F.conv_transpose1d(input, self.weight, stride=self.hop_size) + coeff = self.window.unsqueeze(1).repeat(1, 1, input.size(-1)) ** 2 + coeff = F.conv_transpose1d(coeff, self.enframe, stride=self.hop_size) + out = out / (coeff + 1e-8) + pad = self.window_len - self.hop_size + out = out[..., pad:-pad] + return out diff --git a/tests/models/complexnn_test.py b/tests/models/complexnn_test.py new file mode 100644 index 0000000..524a6cf --- /dev/null +++ b/tests/models/complexnn_test.py @@ -0,0 +1,50 @@ +import torch + +from enhancer.models.complexnn.conv import ComplexConv2d, ComplexConvTranspose2d +from enhancer.models.complexnn.rnn import ComplexLSTM +from enhancer.models.complexnn.utils import ComplexBatchNorm2D + + +def test_complexconv2d(): + sample_input = torch.rand(1, 2, 256, 13) + conv = ComplexConv2d( + 2, 32, kernel_size=(5, 2), stride=(2, 1), padding=(2, 1) + ) + with torch.no_grad(): + out = conv(sample_input) + assert out.shape == torch.Size([1, 32, 128, 13]) + + +def test_complexconvtranspose2d(): + sample_input = torch.rand(1, 512, 4, 13) + conv = ComplexConvTranspose2d( + 256 * 2, + 128 * 2, + kernel_size=(5, 2), + stride=(2, 1), + padding=(2, 0), + output_padding=(1, 0), + ) + with torch.no_grad(): + out = conv(sample_input) + + assert out.shape == torch.Size([1, 256, 8, 14]) + + +def test_complexlstm(): + sample_input = torch.rand(13, 2, 128) + lstm = ComplexLSTM(128 * 2, 128 * 2, projection_size=512 * 2) + with torch.no_grad(): + out = lstm(sample_input) + + assert out[0].shape == torch.Size([13, 1, 512]) + assert out[1].shape == torch.Size([13, 1, 512]) + + +def test_complexbatchnorm2d(): + sample_input = torch.rand(1, 64, 64, 14) + batchnorm = ComplexBatchNorm2D(num_features=64) + with torch.no_grad(): + out = batchnorm(sample_input) + + assert out.size() == sample_input.size() diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py index f5a0ec4..29e030e 100644 --- a/tests/models/demucs_test.py +++ b/tests/models/demucs_test.py @@ -30,7 +30,7 @@ def test_forward(batch_size, samples): data = torch.rand(batch_size, 2, samples, requires_grad=False) with torch.no_grad(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): _ = model(data) diff --git a/tests/models/test_dccrn.py b/tests/models/test_dccrn.py new file mode 100644 index 0000000..96a853b --- /dev/null +++ b/tests/models/test_dccrn.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from enhancer.data.dataset import 
EnhancerDataset
+from enhancer.models.dccrn import DCCRN
+from enhancer.utils.config import Files
+
+
+@pytest.fixture
+def vctk_dataset():
+    root_dir = "tests/data/vctk"
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
+    return dataset
+
+
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
+    model = DCCRN()
+    model.eval()
+
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
+    with torch.no_grad():
+        _ = model(data)
+
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
+    with torch.no_grad():
+        with pytest.raises(ValueError):
+            _ = model(data)
+
+
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_dccrn_init(dataset, channels, loss):
+    with torch.no_grad():
+        _ = DCCRN(num_channels=channels, dataset=dataset, loss=loss)
diff --git a/tests/transforms_test.py b/tests/transforms_test.py
new file mode 100644
index 0000000..89425ad
--- /dev/null
+++ b/tests/transforms_test.py
@@ -0,0 +1,18 @@
+import torch
+
+from enhancer.utils.transforms import ConviSTFT, ConvSTFT
+
+
+def test_stft_istft():
+    sample_input = torch.rand(1, 1, 16000)
+    stft = ConvSTFT(window_len=400, hop_size=100, nfft=512)
+    istft = ConviSTFT(window_len=400, hop_size=100, nfft=512)
+
+    with torch.no_grad():
+        spectrogram = stft(sample_input)
+        waveform = istft(spectrogram)
+    assert sample_input.shape == waveform.shape
+    assert (
+        torch.isclose(waveform, sample_input).sum().item()
+        > sample_input.shape[-1] // 2
+    )
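
A minimal usage sketch of the new model, assuming the defaults from enhancer/cli/train_config/model/DCCRN.yaml and the tensor layout exercised in tests/models/test_dccrn.py (the hyperparameter dicts below are taken from that config, not a definitive API reference):

import torch

from enhancer.models.dccrn import DCCRN

# Hyperparameters mirror enhancer/cli/train_config/model/DCCRN.yaml;
# any key omitted here falls back to the *_DEFAULTS dicts in DCCRN.
model = DCCRN(
    num_channels=1,
    sampling_rate=16000,
    stft={"window_len": 400, "hop_size": 100, "nfft": 512},
    encoder_decoder={"initial_output_channels": 32, "depth": 6},
    lstm={"num_layers": 2, "hidden_size": 256},
)
model.eval()

# Input is (batch, channels, samples) mono audio, as in the forward test.
noisy = torch.rand(1, 1, 16000)
with torch.no_grad():
    enhanced = model(noisy)  # enhanced waveform, clamped to [-1, 1] by forward()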