diff --git a/.gitignore b/.gitignore
index 9cd222c..cd1b1e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 #local
 *.ckpt
+*_local.yaml
 cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/
diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
index de52647..4fa4438 100644
--- a/enhancer/cli/train_config/trainer/default.yaml
+++ b/enhancer/cli/train_config/trainer/default.yaml
@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: -1
+devices: 1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index d140174..bc1f0a2 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,9 +1,12 @@
 import math
 import multiprocessing
 import os
+import random
+from itertools import chain, cycle
 from typing import Optional

 import pytorch_lightning as pl
+import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, IterableDataset

@@ -55,6 +58,7 @@ class TaskDataset(pl.LightningDataModule):
         files: Files,
         valid_minutes: float = 0.20,
         duration: float = 1.0,
+        stride=None,
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,
@@ -65,6 +69,7 @@
         self.name = name
         self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
+        self.stride = stride or duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
@@ -72,6 +77,7 @@
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        print("num_workers-main", self.num_workers)
         if valid_minutes > 0.0:
             self.valid_minutes = valid_minutes
         else:
@@ -90,11 +96,15 @@
                 self.name, train_clean, train_noisy, self.matching_function
             )
             train_data = fp.prepare_matching_dict()
-            self.train_data, self.val_data = self.train_valid_split(
+            train_data, self.val_data = self.train_valid_split(
                 train_data, valid_minutes=self.valid_minutes, random_state=42
             )

+            self.train_data = self.prepare_traindata(train_data)
             self._validation = self.prepare_mapstype(self.val_data)
+            print(
+                "train_data_size", sum([len(item) for item in self.train_data])
+            )

             test_clean = os.path.join(self.root_dir, self.files.test_clean)
             test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
@@ -126,6 +136,32 @@
         valid_data = [item for i, item in enumerate(data) if i in valid_indices]
         return train_data, valid_data

+    def prepare_traindata(self, data):
+        train_data = []
+        for item in data:
+            samples_metadata = []
+            clean, noisy, total_dur = item.values()
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start = index * self.stride
+                samples_metadata.append(
+                    ({"clean": clean, "noisy": noisy}, start)
+                )
+            train_data.append(samples_metadata)
+        return train_data
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+
+        if file_duration < duration:
+            num_segments = 1
+        else:
+            num_segments = math.ceil((file_duration - duration) / stride) + 1
+
+        return num_segments
+
     def prepare_mapstype(self, data):

         metadata = []
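The windowing arithmetic above is easier to verify in isolation. Below is a minimal sketch of the same rule with a worked example; the function mirrors the patch, the input values are illustrative:

    import math

    def get_num_segments(file_duration, duration, stride):
        # Files shorter than one window still contribute a single segment.
        if file_duration < duration:
            return 1
        # One window per stride, plus the final (possibly padded) window.
        return math.ceil((file_duration - duration) / stride) + 1

    # A 4.5 s file cut into 1.0 s windows with a 1.0 s stride:
    # ceil((4.5 - 1.0) / 1.0) + 1 = 5 segments starting at 0, 1, 2, 3 and 4 s.
    assert get_num_segments(4.5, 1.0, 1.0) == 5
    assert get_num_segments(0.5, 1.0, 1.0) == 1

prepare_traindata records one (metadata, start) tuple per window, so the size of a training epoch is known up front instead of being sampled on the fly.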
[], "clean": []} + for item in batch: + output["noisy"].append(item["noisy"]) + output["clean"].append(item["clean"]) + + output["clean"] = torch.stack(output["clean"], dim=0) + output["noisy"] = torch.stack(output["noisy"], dim=0) + return output + + def worker_init_fn(self, _): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + worker_id = worker_info.id + split_size = len(dataset.dataset.train_data) // worker_info.num_workers + dataset.data = dataset.dataset.train_data[ + worker_id * split_size : (worker_id + 1) * split_size + ] + def train_dataloader(self): return DataLoader( TrainDataset(self), - batch_size=self.batch_size, + batch_size=None, num_workers=self.num_workers, + collate_fn=self.train_collatefn, + worker_init_fn=self.worker_init_fn, ) def val_dataloader(self): @@ -227,24 +284,25 @@ class EnhancerDataset(TaskDataset): super().setup(stage=stage) + def random_sample(self, train_data): + return random.sample(train_data, len(train_data)) + def train__iter__(self): - rng = create_unique_rng(self.model.current_epoch) + train_data = rng.sample(self.train_data, len(self.train_data)) + return zip( + *[ + self.get_stream(self.random_sample(train_data)) + for i in range(self.batch_size) + ] + ) - while True: + def get_stream(self, data): + return chain.from_iterable(map(self.process_data, cycle(data))) - file_dict, *_ = rng.choices( - self.train_data, - k=1, - weights=[file["duration"] for file in self.train_data], - ) - file_duration = file_dict["duration"] - num_segments = self.get_num_segments( - file_duration, self.duration, self.stride - ) - for index in range(0, num_segments): - start_time = index * self.stride - yield self.prepare_segment(file_dict, start_time) + def process_data(self, data): + for item in data: + yield self.prepare_segment(*item) @staticmethod def get_num_segments(file_duration, duration, stride): @@ -288,20 +346,14 @@ class EnhancerDataset(TaskDataset): ), ), ) - return {"clean": clean_segment, "noisy": noisy_segment} + return { + "clean": clean_segment, + "noisy": noisy_segment, + "name": file_dict["clean"].split("/")[-1] + "->" + str(start_time), + } def train__len__(self): - - return math.ceil( - sum( - [ - self.get_num_segments( - file["duration"], self.duration, self.stride - ) - for file in self.train_data - ] - ) - ) + return sum([len(item) for item in self.train_data]) // (self.batch_size) def val__len__(self): return len(self._validation) diff --git a/enhancer/loss.py b/enhancer/loss.py index 5092656..f2be8df 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -3,7 +3,7 @@ import logging import numpy as np import torch import torch.nn as nn -from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality +from pesq import pesq from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility @@ -122,7 +122,6 @@ class Pesq: self.sr = sr self.name = "pesq" self.mode = mode - self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode) def __call__(self, prediction: torch.Tensor, target: torch.Tensor): @@ -130,7 +129,12 @@ class Pesq: for pred, target_ in zip(prediction, target): try: pesq_values.append( - self.pesq(pred.squeeze(), target_.squeeze()).item() + pesq( + self.sr, + target_.squeeze().detach().numpy(), + pred.squeeze().detach().numpy(), + self.mode, + ) ) except Exception as e: logging.warning(f"{e} error occured while calculating PESQ") diff --git a/requirements.txt b/requirements.txt index 95f145d..fa5e41c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ joblib>=1.2.0 
diff --git a/enhancer/loss.py b/enhancer/loss.py
index 5092656..f2be8df 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -3,7 +3,7 @@ import logging
 import numpy as np
 import torch
 import torch.nn as nn
-from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from pesq import pesq
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility


@@ -122,7 +122,6 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
-        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)

     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

@@ -130,7 +129,12 @@
         pesq_values = []
         for pred, target_ in zip(prediction, target):
             try:
                 pesq_values.append(
-                    self.pesq(pred.squeeze(), target_.squeeze()).item()
+                    pesq(
+                        self.sr,
+                        target_.squeeze().detach().numpy(),
+                        pred.squeeze().detach().numpy(),
+                        self.mode,
+                    )
                 )
             except Exception as e:
                 logging.warning(f"{e} error occured while calculating PESQ")
diff --git a/requirements.txt b/requirements.txt
index 95f145d..fa5e41c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
-pesq==0.0.4
+git+https://github.com/ludlows/python-pesq#egg=pesq
 protobuf>=3.19.6
 pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3
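With the torchmetrics wrapper removed, PESQ is computed directly through the pesq package pinned above. A minimal standalone check of the call pattern used in loss.py; the signals are synthetic placeholders (real speech is needed for meaningful scores), and pesq itself only accepts 8 kHz or 16 kHz audio, so anything else must be resampled first:

    import numpy as np
    from pesq import pesq  # https://github.com/ludlows/python-pesq

    sr = 16000  # 'wb' mode requires 16 kHz; 'nb' also allows 8 kHz
    ref = np.random.randn(sr).astype(np.float32)                # 1 s reference
    deg = ref + 0.01 * np.random.randn(sr).astype(np.float32)   # degraded copy
    # Argument order matches the patched loss.py: rate, reference, degraded, mode.
    score = pesq(sr, ref, deg, "wb")
    print(score)  # MOS-LQO, roughly 1.0 (bad) up to about 4.64 (transparent)

Scores are computed one example at a time on CPU numpy arrays, which is why loss.py detaches and converts each tensor before the call.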