From 415ed8e3d0733bccf1d94b25100dfb03ba4fc4ad Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 18 Oct 2022 15:22:34 +0530 Subject: [PATCH 1/2] normalize input --- enhancer/models/demucs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 95d6a6f..86afb6c 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -133,10 +133,12 @@ class Demucs(Model): num_channels: int = 1, resample: int = 4, sampling_rate=16000, + normalize=True, lr: float = 1e-3, dataset: Optional[EnhancerDataset] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", + floor=1e-3, ): duration = ( dataset.duration if isinstance(dataset, EnhancerDataset) else None @@ -161,6 +163,8 @@ class Demucs(Model): lstm = merge_dict(self.LSTM_DEFAULTS, lstm) self.save_hyperparameters("encoder_decoder", "lstm", "resample") hidden = encoder_decoder["initial_output_channels"] + self.normalize = normalize + self.floor = floor self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() @@ -204,7 +208,10 @@ class Demucs(Model): raise TypeError( f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" ) - + if self.normalize: + waveform = waveform.mean(dim=1, keepdim=True) + std = waveform.std(dim=-1, keepdim=True) + waveform = waveform / (self.floor + std) length = waveform.shape[-1] x = F.pad(waveform, (0, self.get_padding_length(length) - length)) if self.hparams.resample > 1: From edb7f020f7ca8fa05d40087db73c2b556e400c77 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 18 Oct 2022 15:23:07 +0530 Subject: [PATCH 2/2] stride waveform --- enhancer/data/dataset.py | 44 ++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 02a1d3b..28f19a6 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -132,11 +132,14 @@ class TaskDataset(pl.LightningDataModule): for item in data: clean, noisy, total_dur = item.values() if total_dur < self.duration: - continue - num_segments = round(total_dur / self.duration) - for index in range(num_segments): - start_time = index * self.duration - metadata.append(({"clean": clean, "noisy": noisy}, start_time)) + metadata.append(({"clean": clean, "noisy": noisy}, 0.0)) + else: + num_segments = round(total_dur / self.duration) + for index in range(num_segments): + start_time = index * self.duration + metadata.append( + ({"clean": clean, "noisy": noisy}, start_time) + ) return metadata def train_dataloader(self): @@ -195,6 +198,7 @@ class EnhancerDataset(TaskDataset): files: Files, valid_minutes=5.0, duration=1.0, + stride=0.5, sampling_rate=48000, matching_function=None, batch_size=32, @@ -217,6 +221,7 @@ class EnhancerDataset(TaskDataset): self.files = files self.duration = max(1.0, duration) self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) + self.stride = stride or duration def setup(self, stage: Optional[str] = None): @@ -234,9 +239,22 @@ class EnhancerDataset(TaskDataset): weights=[file["duration"] for file in self.train_data], ) file_duration = file_dict["duration"] - start_time = round(rng.uniform(0, file_duration - self.duration), 2) - data = self.prepare_segment(file_dict, start_time) - yield data + num_segments = self.get_num_segments( + file_duration, self.duration, self.stride + ) + for index in range(0, num_segments): + start_time = index * self.stride + yield self.prepare_segment(file_dict, start_time) + + @staticmethod + def get_num_segments(file_duration, duration, stride): + + if file_duration < duration: + num_segments = 1 + else: + num_segments = math.ceil((file_duration - duration) / stride) + 1 + + return num_segments def val__getitem__(self, idx): return self.prepare_segment(*self._validation[idx]) @@ -273,8 +291,16 @@ class EnhancerDataset(TaskDataset): return {"clean": clean_segment, "noisy": noisy_segment} def train__len__(self): + return math.ceil( - sum([file["duration"] for file in self.train_data]) / self.duration + sum( + [ + self.get_num_segments( + file["duration"], self.duration, self.stride + ) + for file in self.train_data + ] + ) ) def val__len__(self):