From 2ad49faa677a66d49406fdc108cfe9d1fa46dd29 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 20 Oct 2022 09:49:27 +0530
Subject: [PATCH 01/21] debug iterative dataset

---
 enhancer/data/dataset.py | 102 ++++++++++++++++++++++++++++-----------
 1 file changed, 74 insertions(+), 28 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 28f19a6..8110a2a 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,9 +1,11 @@
 import math
 import multiprocessing
 import os
+from itertools import chain, cycle
 from typing import Optional
 
 import pytorch_lightning as pl
+import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
@@ -55,6 +57,7 @@ class TaskDataset(pl.LightningDataModule):
         files: Files,
         valid_minutes: float = 0.20,
         duration: float = 1.0,
+        stride=None,
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,
@@ -65,6 +68,7 @@ class TaskDataset(pl.LightningDataModule):
         self.name = name
         self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
+        self.stride = stride or duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
@@ -90,10 +94,11 @@ class TaskDataset(pl.LightningDataModule):
                 self.name, train_clean, train_noisy, self.matching_function
             )
             train_data = fp.prepare_matching_dict()
-            self.train_data, self.val_data = self.train_valid_split(
+            train_data, self.val_data = self.train_valid_split(
                 train_data, valid_minutes=self.valid_minutes, random_state=42
             )
 
+            self.train_data = self.prepare_traindata(train_data)
             self._validation = self.prepare_mapstype(self.val_data)
 
             test_clean = os.path.join(self.root_dir, self.files.test_clean)
@@ -112,7 +117,7 @@ class TaskDataset(pl.LightningDataModule):
         valid_min_now = 0.0
         valid_indices = []
         random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state)
+        rng = create_unique_rng(random_state, 0)
         rng.shuffle(random_indices)
         i = 0
         while valid_min_now <= valid_minutes:
@@ -126,6 +131,33 @@ class TaskDataset(pl.LightningDataModule):
         valid_data = [item for i, item in enumerate(data) if i in valid_indices]
         return train_data, valid_data
 
+    def prepare_traindata(self, data):
+        train_data = []
+        for item in data:
+            samples_metadata = []
+            clean, noisy, total_dur = item.values()
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start = index * self.stride
+                samples_metadata.append(
+                    ({"clean": clean, "noisy": noisy}, start)
+                )
+            train_data.append(samples_metadata)
+        print(train_data[:10])
+        return train_data
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+
+        if file_duration < duration:
+            num_segments = 1
+        else:
+            num_segments = math.ceil((file_duration - duration) / stride) + 1
+
+        return num_segments
+
     def prepare_mapstype(self, data):
 
         metadata = []
@@ -142,11 +174,33 @@ class TaskDataset(pl.LightningDataModule):
         )
         return metadata
 
+    def train_collatefn(self, batch):
+
+        output = {"noisy": [], "clean": []}
+        for item in batch:
+            output["noisy"].append(item["noisy"])
+            output["clean"].append(item["clean"])
+
+        output["clean"] = torch.stack(output["clean"], dim=0)
+        output["noisy"] = torch.stack(output["noisy"], dim=0)
+        return output
+
+    def worker_init_fn(self):
+        worker_info = torch.utils.data.get_worker_info()
+        dataset = worker_info.dataset
+        worker_id = worker_info.id
+        split_size = len(dataset.data) // worker_info.num_workers
+        dataset.data = dataset.data[
+            worker_id * split_size : (worker_id + 1) * split_size
+        ]
+
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=self.batch_size,
+            batch_size=None,
             num_workers=self.num_workers,
+            collate_fn=self.train_collatefn,
+            worker_init_fn=self.worker_init_fn,
         )
 
     def val_dataloader(self):
@@ -227,24 +281,24 @@ class EnhancerDataset(TaskDataset):
 
         super().setup(stage=stage)
 
+    def random_sample(self, index):
+        rng = create_unique_rng(self.model.current_epoch, index)
+        return rng.sample(self.train_data, len(self.train_data))
+
     def train__iter__(self):
+        return zip(
+            *[
+                self.get_stream(self.random_sample(i))
+                for i in range(self.batch_size)
+            ]
+        )
 
-        rng = create_unique_rng(self.model.current_epoch)
+    def get_stream(self, data):
+        return chain.from_iterable(map(self.process_data, cycle(data)))
 
-        while True:
-
-            file_dict, *_ = rng.choices(
-                self.train_data,
-                k=1,
-                weights=[file["duration"] for file in self.train_data],
-            )
-            file_duration = file_dict["duration"]
-            num_segments = self.get_num_segments(
-                file_duration, self.duration, self.stride
-            )
-            for index in range(0, num_segments):
-                start_time = index * self.stride
-                yield self.prepare_segment(file_dict, start_time)
+    def process_data(self, data):
+        for item in data:
+            yield self.prepare_segment(*item)
 
     @staticmethod
     def get_num_segments(file_duration, duration, stride):
@@ -264,6 +318,7 @@ class EnhancerDataset(TaskDataset):
 
     def prepare_segment(self, file_dict: dict, start_time: float):
 
+        print(file_dict["clean"].split("/")[-1], "->", start_time)
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -292,16 +347,7 @@ class EnhancerDataset(TaskDataset):
 
     def train__len__(self):
 
-        return math.ceil(
-            sum(
-                [
-                    self.get_num_segments(
-                        file["duration"], self.duration, self.stride
-                    )
-                    for file in self.train_data
-                ]
-            )
-        )
+        return sum([len(item) for item in self.train_data])
 
     def val__len__(self):
         return len(self._validation)

From a6a2e4a4ae2d607f9b168ae0f2eda6885350f209 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 20 Oct 2022 09:50:04 +0530
Subject: [PATCH 02/21] add batch info

---
 enhancer/utils/random.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py
index dd9395a..2feb581 100644
--- a/enhancer/utils/random.py
+++ b/enhancer/utils/random.py
@@ -4,7 +4,7 @@ import random
 import torch
 
 
-def create_unique_rng(epoch: int):
+def create_unique_rng(epoch: int, index: int):
     """create unique random number generator for each (worker_id,epoch) combination"""
 
     rng = random.Random()
@@ -29,6 +29,7 @@ def create_unique_rng(epoch: int):
         + local_rank * num_workers
         + node_rank * num_workers * global_rank
         + epoch * num_workers * world_size
+        + index
     )
 
     rng.seed(seed)

From c5824cb34a88d5871a4237f877448e3da0f3badc Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 20 Oct 2022 09:53:06 +0530
Subject: [PATCH 03/21] gitignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 9cd222c..cd1b1e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 #local
 *.ckpt
+*_local.yaml
 cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/
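The batching scheme that PATCH 01 switches to is easiest to see in isolation. Below is a minimal, self-contained sketch of the `get_stream`/`random_sample`/`train__iter__` trio under simplified assumptions: plain tuples stand in for the `prepare_segment` output, and the per-stream shuffle uses Python's global RNG rather than `create_unique_rng`.

    import random
    from itertools import chain, cycle

    def get_stream(file_list):
        # Endlessly cycle over files, flattening each file's segment list
        # into one continuous stream of samples.
        return chain.from_iterable(map(iter, cycle(file_list)))

    def train_iter(train_data, batch_size):
        # One independently shuffled stream per batch slot; zipping the
        # streams element-wise yields one batch per step.
        streams = [
            get_stream(random.sample(train_data, len(train_data)))
            for _ in range(batch_size)
        ]
        return zip(*streams)

    # train_data mirrors prepare_traindata's output: one segment list per file.
    train_data = [[("a.wav", 0.0), ("a.wav", 0.5)], [("b.wav", 0.0)]]
    print(next(iter(train_iter(train_data, batch_size=2))))

Because every stream is infinite, it is `train__len__` rather than `zip` that decides when an epoch ends, which is why the later patches keep adjusting `__len__`.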
From f2561d7cf7f6cd1069310ddbdc4e0074da1e9b0d Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 20 Oct 2022 09:53:27 +0530
Subject: [PATCH 04/21] config

---
 enhancer/cli/train_config/config.yaml          | 2 +-
 enhancer/cli/train_config/dataset/Vctk.yaml    | 5 +++--
 enhancer/cli/train_config/trainer/default.yaml | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/enhancer/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml
index c0b2cf6..8d0ab14 100644
--- a/enhancer/cli/train_config/config.yaml
+++ b/enhancer/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : WaveUnet
+  - model : Demucs
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index df50da2..c128404 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -1,9 +1,10 @@
 _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
-duration : 1.0
+duration : 4.5
+stride : 0.5
 sampling_rate: 16000
-batch_size: 128
+batch_size: 16
 
 files:
   train_clean : clean_trainset_28spk_wav
diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
index dfc020f..01914e4 100644
--- a/enhancer/cli/train_config/trainer/default.yaml
+++ b/enhancer/cli/train_config/trainer/default.yaml
@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: -1
+devices: 1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True

From ba10719520d6fbb906d84ac4a3e7374b3279b893 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Thu, 20 Oct 2022 21:03:38 +0530
Subject: [PATCH 05/21] add arg

---
 enhancer/data/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 8110a2a..06a6f67 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -185,7 +185,7 @@ class TaskDataset(pl.LightningDataModule):
         output["noisy"] = torch.stack(output["noisy"], dim=0)
         return output
 
-    def worker_init_fn(self):
+    def worker_init_fn(self, _):
         worker_info = torch.utils.data.get_worker_info()
         dataset = worker_info.dataset
         worker_id = worker_info.id

From 178a4523ef06e272128e431d9d8adad1fa1c7592 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 09:48:28 +0530
Subject: [PATCH 06/21] fix worker init fn

---
 enhancer/data/dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 06a6f67..218cd7c 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -189,8 +189,8 @@ class TaskDataset(pl.LightningDataModule):
         worker_info = torch.utils.data.get_worker_info()
         dataset = worker_info.dataset
         worker_id = worker_info.id
-        split_size = len(dataset.data) // worker_info.num_workers
-        dataset.data = dataset.data[
+        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
+        dataset.data = dataset.dataset.train_data[
             worker_id * split_size : (worker_id + 1) * split_size
         ]

From 0d3bfd341210f2f43d9162e1be68a5e389f1de2c Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 11:13:17 +0530
Subject: [PATCH 07/21] debug

---
 enhancer/cli/train_config/dataset/Vctk.yaml |  3 ++-
 enhancer/data/dataset.py                    | 14 +++++++++-----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index c128404..2ea4018 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -4,7 +4,8 @@ root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
 stride : 0.5
 sampling_rate: 16000
-batch_size: 16
+batch_size: 4
+valid_minutes : 1
 
 files:
   train_clean : clean_trainset_28spk_wav
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 218cd7c..3722763 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -145,8 +145,7 @@ class TaskDataset(pl.LightningDataModule):
                     ({"clean": clean, "noisy": noisy}, start)
                 )
             train_data.append(samples_metadata)
-        print(train_data[:10])
-        return train_data
+        return train_data[:25]
 
     @staticmethod
     def get_num_segments(file_duration, duration, stride):
@@ -175,12 +174,14 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
-
+        names = []
         output = {"noisy": [], "clean": []}
         for item in batch:
             output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
+            names.append(item["name"])
 
+        print(names)
         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)
         return output
@@ -318,7 +319,6 @@ class EnhancerDataset(TaskDataset):
 
     def prepare_segment(self, file_dict: dict, start_time: float):
 
-        print(file_dict["clean"].split("/")[-1], "->", start_time)
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -343,7 +343,11 @@ class EnhancerDataset(TaskDataset):
                 ),
             ),
         )
-        return {"clean": clean_segment, "noisy": noisy_segment}
+        return {
+            "clean": clean_segment,
+            "noisy": noisy_segment,
+            "name": file_dict["clean"].split("/")[-1] + "->" + start_time,
+        }
 
     def train__len__(self):

From 5d7ea582c9664610c7a753500f4ff74a80697dbf Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 11:18:24 +0530
Subject: [PATCH 08/21] debug

---
 enhancer/data/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 3722763..9b512ab 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -346,7 +346,7 @@ class EnhancerDataset(TaskDataset):
         return {
             "clean": clean_segment,
             "noisy": noisy_segment,
-            "name": file_dict["clean"].split("/")[-1] + "->" + start_time,
+            "name": file_dict["clean"].split("/")[-1] + "->" + str(start_time),
         }

From 9c7a650130aa9ba896846eb35ff05942ecda1395 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 11:37:26 +0530
Subject: [PATCH 09/21] div by batchsize in __len__

---
 enhancer/data/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 9b512ab..cc89083 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -351,7 +351,7 @@ class EnhancerDataset(TaskDataset):
 
     def train__len__(self):
 
-        return sum([len(item) for item in self.train_data])
+        return sum([len(item) for item in self.train_data]) // self.batch_size
 
     def val__len__(self):
         return len(self._validation)
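Before the next round of debugging, it helps to pin down the segment arithmetic these patches rely on. With the Vctk values above (duration 4.5 s, stride 0.5 s), `get_num_segments` from PATCH 01 works out as below — a worked check using the same logic as the patch, not new behaviour:

    import math

    def get_num_segments(file_duration, duration, stride):
        # One segment for files shorter than the window; otherwise the
        # ceil guarantees a final (padded) segment covering the file tail.
        if file_duration < duration:
            return 1
        return math.ceil((file_duration - duration) / stride) + 1

    assert get_num_segments(10.0, 4.5, 0.5) == 12  # starts at 0.0, 0.5, ..., 5.5
    assert get_num_segments(3.0, 4.5, 0.5) == 1    # shorter than one window

`train__len__` sums these per-file counts, and dividing by `batch_size` (PATCH 09) converts the sample count into a step count.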
From 20c12556fff9e577045edecd700db1ef33ffe2a2 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 16:25:24 +0530
Subject: [PATCH 10/21] debug

---
 enhancer/data/dataset.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index cc89083..fb4d04c 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -100,6 +100,9 @@ class TaskDataset(pl.LightningDataModule):
 
             self.train_data = self.prepare_traindata(train_data)
             self._validation = self.prepare_mapstype(self.val_data)
+            print(
+                "train_data_size", sum([len(item) for item in self.train_data])
+            )
 
             test_clean = os.path.join(self.root_dir, self.files.test_clean)
             test_noisy = os.path.join(self.root_dir, self.files.test_noisy)

From a7fb27bb0f94b9e7674c1a97005b1c8202fe86f9 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 17:17:02 +0530
Subject: [PATCH 11/21] debug

---
 enhancer/data/dataset.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index fb4d04c..a8b3896 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -353,8 +353,11 @@ class EnhancerDataset(TaskDataset):
         }
 
     def train__len__(self):
-
-        return sum([len(item) for item in self.train_data]) // self.batch_size
+        worker_info = torch.utils.data.get_worker_info()
+        num_workers = worker_info.num_workers if worker_info else 1
+        return sum([len(item) for item in self.train_data]) // (
+            self.batch_size * num_workers
+        )
 
     def val__len__(self):
         return len(self._validation)

From a75f3c32a35e9b8192f0efce4ee38bc7553ba80e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 19:23:59 +0530
Subject: [PATCH 12/21] num_workers

---
 enhancer/cli/train_config/dataset/Vctk.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index 2ea4018..0db0e7e 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -6,7 +6,7 @@ stride : 0.5
 sampling_rate: 16000
 batch_size: 4
 valid_minutes : 1
-
+num_workers: 0
 files:
   train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav

From cd9ffc1a684fe255b3c8a221145e68de9cb410b7 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 23:22:56 +0530
Subject: [PATCH 13/21] fix randomization

---
 enhancer/data/dataset.py | 18 +++++++++++-------
 enhancer/utils/random.py |  3 +--
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index a8b3896..d4686b5 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,6 +1,7 @@
 import math
 import multiprocessing
 import os
+import random
 from itertools import chain, cycle
 from typing import Optional
 
@@ -120,7 +121,7 @@ class TaskDataset(pl.LightningDataModule):
         valid_min_now = 0.0
         valid_indices = []
         random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state, 0)
+        rng = create_unique_rng(random_state)
         rng.shuffle(random_indices)
         i = 0
         while valid_min_now <= valid_minutes:
@@ -285,14 +286,15 @@ class EnhancerDataset(TaskDataset):
 
         super().setup(stage=stage)
 
-    def random_sample(self, index):
-        rng = create_unique_rng(self.model.current_epoch, index)
-        return rng.sample(self.train_data, len(self.train_data))
+    def random_sample(self, train_data):
+        return random.sample(train_data, len(train_data))
 
     def train__iter__(self):
+        rng = create_unique_rng(self.model.current_epoch)
+        train_data = rng.sample(self.train_data, len(self.train_data))
         return zip(
             *[
-                self.get_stream(self.random_sample(i))
+                self.get_stream(self.random_sample(train_data))
                 for i in range(self.batch_size)
             ]
         )
@@ -353,8 +355,10 @@ class EnhancerDataset(TaskDataset):
         }
 
     def train__len__(self):
-        worker_info = torch.utils.data.get_worker_info()
-        num_workers = worker_info.num_workers if worker_info else 1
+        if self.num_workers > 1:
+            num_workers = 2
+        else:
+            num_workers = 1
         return sum([len(item) for item in self.train_data]) // (
             self.batch_size * num_workers
         )
diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py
index 2feb581..dd9395a 100644
--- a/enhancer/utils/random.py
+++ b/enhancer/utils/random.py
@@ -4,7 +4,7 @@ import random
 import torch
 
 
-def create_unique_rng(epoch: int, index: int):
+def create_unique_rng(epoch: int):
     """create unique random number generator for each (worker_id,epoch) combination"""
 
     rng = random.Random()
@@ -29,7 +29,6 @@ def create_unique_rng(epoch: int, index: int):
         + local_rank * num_workers
         + node_rank * num_workers * global_rank
         + epoch * num_workers * world_size
-        + index
     )
 
     rng.seed(seed)

From 8457e1cbe2cda4299bdd5fe56058fab2d07a2247 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Fri, 21 Oct 2022 23:23:37 +0530
Subject: [PATCH 14/21] debug num_workers

---
 enhancer/cli/train_config/dataset/Vctk.yaml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
index 0db0e7e..0e1f38f 100644
--- a/enhancer/cli/train_config/dataset/Vctk.yaml
+++ b/enhancer/cli/train_config/dataset/Vctk.yaml
@@ -6,7 +6,6 @@ stride : 0.5
 sampling_rate: 16000
 batch_size: 4
 valid_minutes : 1
-num_workers: 0
 files:
   train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav

From c4a27686daa6bca8ad905f33ef177f193c45d4e4 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sat, 22 Oct 2022 09:57:27 +0530
Subject: [PATCH 15/21] debug

---
 enhancer/data/dataset.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index d4686b5..1ae48ed 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -77,6 +77,7 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        print("num_workers-main", self.num_workers)
         if valid_minutes > 0.0:
             self.valid_minutes = valid_minutes
         else:
@@ -184,7 +185,6 @@ class TaskDataset(pl.LightningDataModule):
             output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
             names.append(item["name"])
 
-        print(names)
         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)
@@ -359,8 +359,9 @@ class EnhancerDataset(TaskDataset):
             num_workers = 2
         else:
             num_workers = 1
+        print("num_workers", num_workers)
         return sum([len(item) for item in self.train_data]) // (
-            self.batch_size * num_workers
+            self.batch_size * self.num_workers
         )

From 7fa54fc414f438a1cae888e93a31edb1f604c01e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sat, 22 Oct 2022 10:30:27 +0530
Subject: [PATCH 16/21] debug

---
 enhancer/data/dataset.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 1ae48ed..05dd287 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -355,14 +355,12 @@ class EnhancerDataset(TaskDataset):
         }
 
     def train__len__(self):
-        if self.num_workers > 1:
-            num_workers = 2
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is None:
+            train_data = self.train_data
         else:
-            num_workers = 1
-        print("num_workers", num_workers)
-        return sum([len(item) for item in self.train_data]) // (
-            self.batch_size * self.num_workers
-        )
+            train_data = worker_info.dataset.data
+        return sum([len(item) for item in train_data]) // (self.batch_size)
 
     def val__len__(self):
         return len(self._validation)
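The fix in PATCH 13 hinges on one property of `create_unique_rng`: seeding from the epoch makes the file-order shuffle identical across processes and across resumed runs, while the plain `random.sample` inside `random_sample` decorrelates the per-slot streams. A toy illustration with a deliberately simplified seed (the real helper in `enhancer/utils/random.py` also mixes in worker id, local/node rank and world size; the values here are hypothetical):

    import random

    def create_unique_rng(epoch, worker_id=0, num_workers=1, world_size=1):
        # Simplified stand-in for enhancer/utils/random.py.
        rng = random.Random()
        rng.seed(worker_id + epoch * num_workers * world_size)
        return rng

    files = list(range(6))
    order_a = create_unique_rng(epoch=3).sample(files, len(files))
    order_b = create_unique_rng(epoch=3).sample(files, len(files))
    assert order_a == order_b  # same epoch -> same file order on every process

A different epoch reseeds the generator, so each epoch still sees a fresh shuffle.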
From 6314d210c35f89fe61698d13213a613270628d3e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sat, 22 Oct 2022 11:05:19 +0530
Subject: [PATCH 17/21] debug git commit -m debug '

---
 enhancer/data/dataset.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 05dd287..c808c8e 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -360,7 +360,9 @@ class EnhancerDataset(TaskDataset):
             train_data = self.train_data
         else:
             train_data = worker_info.dataset.data
-        return sum([len(item) for item in train_data]) // (self.batch_size)
+        len = sum([len(item) for item in train_data]) // (self.batch_size)
+        print(len)
+        return len
 
     def val__len__(self):
         return len(self._validation)

From 9b155348126d6d63b72488ccb3edc79db92fbfd7 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sat, 22 Oct 2022 11:05:39 +0530
Subject: [PATCH 18/21] print len

---
 enhancer/data/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index c808c8e..c48aded 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -361,7 +361,7 @@ class EnhancerDataset(TaskDataset):
         else:
             train_data = worker_info.dataset.data
         len = sum([len(item) for item in train_data]) // (self.batch_size)
-        print(len)
+        print("workers", len)
         return len
 
     def val__len__(self):

From 05e40f84b6eefee8831508e90252cd8f9cca7bc2 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sat, 22 Oct 2022 11:17:22 +0530
Subject: [PATCH 19/21] replace pesq

---
 enhancer/loss.py | 10 +++++++---
 requirements.txt |  2 +-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/enhancer/loss.py b/enhancer/loss.py
index 5092656..f2be8df 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -3,7 +3,7 @@ import logging
 import numpy as np
 import torch
 import torch.nn as nn
-from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from pesq import pesq
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
 
 
@@ -122,7 +122,6 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
-        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
 
     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
 
@@ -130,7 +129,12 @@ class Pesq:
         for pred, target_ in zip(prediction, target):
             try:
                 pesq_values.append(
-                    self.pesq(pred.squeeze(), target_.squeeze()).item()
+                    pesq(
+                        self.sr,
+                        target_.squeeze().detach().numpy(),
+                        pred.squeeze().detach().numpy(),
+                        self.mode,
+                    )
                 )
             except Exception as e:
                 logging.warning(f"{e} error occured while calculating PESQ")
diff --git a/requirements.txt b/requirements.txt
index 95f145d..fa5e41c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
-pesq==0.0.4
+git+https://github.com/ludlows/python-pesq#egg=pesq
 protobuf>=3.19.6
 pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3

From 5f1ed8c725931d0a59df7f0eb163af5db93665cf Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sat, 22 Oct 2022 11:17:37 +0530
Subject: [PATCH 20/21] iterable dataset

---
 enhancer/data/dataset.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index c48aded..80055b6 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -179,13 +179,11 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
-        names = []
         output = {"noisy": [], "clean": []}
         for item in batch:
             output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
-            names.append(item["name"])
-
+
         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)
         return output
@@ -355,14 +353,7 @@ class EnhancerDataset(TaskDataset):
         }
 
     def train__len__(self):
-        worker_info = torch.utils.data.get_worker_info()
-        if worker_info is None:
-            train_data = self.train_data
-        else:
-            train_data = worker_info.dataset.data
-        len = sum([len(item) for item in train_data]) // (self.batch_size)
-        print("workers", len)
-        return len
+        return sum([len(item) for item in self.train_data]) // (self.batch_size)
 
     def val__len__(self):
         return len(self._validation)

From 9f658424a6ae23ea9b6f0048f9d6323e92c141d7 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Sat, 22 Oct 2022 11:18:32 +0530
Subject: [PATCH 21/21] rmv slicing

---
 enhancer/data/dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 80055b6..fad1e92 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -150,7 +150,7 @@ class TaskDataset(pl.LightningDataModule):
                 ({"clean": clean, "noisy": noisy}, start)
             )
             train_data.append(samples_metadata)
-        return train_data[:25]
+        return train_data
 
     @staticmethod
     def get_num_segments(file_duration, duration, stride):
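For context on PATCH 19, here is a condensed view of what the metric looks like after the swap, as a standalone function. This is a sketch under assumptions — `(batch, channels, samples)` CPU tensors, 16 kHz wide-band mode — and `batch_pesq` is an illustrative name, not the repo's API (the repo keeps the `Pesq` class):

    import logging

    import numpy as np
    import torch
    from pesq import pesq  # https://github.com/ludlows/python-pesq

    def batch_pesq(prediction: torch.Tensor, target: torch.Tensor,
                   sr: int = 16000, mode: str = "wb") -> torch.Tensor:
        # Score each pair independently; pesq() raises on degenerate
        # (e.g. silent) segments, so failures are logged and skipped,
        # mirroring the try/except in enhancer/loss.py.
        values = []
        for pred, ref in zip(prediction, target):
            try:
                values.append(
                    pesq(
                        sr,
                        ref.squeeze().detach().numpy(),
                        pred.squeeze().detach().numpy(),
                        mode,
                    )
                )
            except Exception as e:
                logging.warning(f"{e} error occurred while calculating PESQ")
        return torch.tensor(np.mean(values)) if values else torch.tensor(0.0)

Unlike the torchmetrics wrapper, the raw `pesq` call takes numpy arrays plus an explicit sampling rate and mode, which is why the patch detaches and converts each tensor before scoring.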