diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 28f19a6..8110a2a 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,9 +1,11 @@
 import math
 import multiprocessing
 import os
+from itertools import chain, cycle
 from typing import Optional
 
 import pytorch_lightning as pl
+import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
@@ -55,6 +57,7 @@ class TaskDataset(pl.LightningDataModule):
         files: Files,
         valid_minutes: float = 0.20,
         duration: float = 1.0,
+        stride=None,
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,
@@ -65,6 +68,7 @@ class TaskDataset(pl.LightningDataModule):
         self.name = name
         self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
+        self.stride = stride or duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
@@ -90,10 +94,11 @@ class TaskDataset(pl.LightningDataModule):
                 self.name, train_clean, train_noisy, self.matching_function
             )
             train_data = fp.prepare_matching_dict()
-            self.train_data, self.val_data = self.train_valid_split(
+            train_data, self.val_data = self.train_valid_split(
                 train_data, valid_minutes=self.valid_minutes, random_state=42
             )
 
+            self.train_data = self.prepare_traindata(train_data)
             self._validation = self.prepare_mapstype(self.val_data)
 
             test_clean = os.path.join(self.root_dir, self.files.test_clean)
@@ -112,7 +117,7 @@ class TaskDataset(pl.LightningDataModule):
         valid_min_now = 0.0
         valid_indices = []
         random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state)
+        rng = create_unique_rng(random_state, 0)
         rng.shuffle(random_indices)
         i = 0
         while valid_min_now <= valid_minutes:
@@ -126,6 +131,33 @@ class TaskDataset(pl.LightningDataModule):
         valid_data = [item for i, item in enumerate(data) if i in valid_indices]
         return train_data, valid_data
 
+    def prepare_traindata(self, data):
+        train_data = []
+        for item in data:
+            samples_metadata = []
+            clean, noisy, total_dur = item.values()
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start = index * self.stride
+                samples_metadata.append(
+                    ({"clean": clean, "noisy": noisy}, start)
+                )
+            train_data.append(samples_metadata)
+        print(train_data[:10])
+        return train_data
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+
+        if file_duration < duration:
+            num_segments = 1
+        else:
+            num_segments = math.ceil((file_duration - duration) / stride) + 1
+
+        return num_segments
+
     def prepare_mapstype(self, data):
 
         metadata = []
@@ -142,11 +174,33 @@ class TaskDataset(pl.LightningDataModule):
             )
         return metadata
 
+    def train_collatefn(self, batch):
+
+        output = {"noisy": [], "clean": []}
+        for item in batch:
+            output["noisy"].append(item["noisy"])
+            output["clean"].append(item["clean"])
+
+        output["clean"] = torch.stack(output["clean"], dim=0)
+        output["noisy"] = torch.stack(output["noisy"], dim=0)
+        return output
+
+    def worker_init_fn(self, worker_id):
+        # DataLoader calls worker_init_fn(worker_id) in each worker process
+        worker_info = torch.utils.data.get_worker_info()
+        dataset = worker_info.dataset
+        split_size = len(dataset.data) // worker_info.num_workers
+        dataset.data = dataset.data[
+            worker_id * split_size : (worker_id + 1) * split_size
+        ]
+
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=self.batch_size,
+            batch_size=None,
             num_workers=self.num_workers,
+            collate_fn=self.train_collatefn,
+            worker_init_fn=self.worker_init_fn,
         )
 
     def val_dataloader(self):
@@ -227,24 +281,24 @@ class EnhancerDataset(TaskDataset):
         super().setup(stage=stage)
 
+    def random_sample(self, index):
+        rng = create_unique_rng(self.model.current_epoch, index)
+        return rng.sample(self.train_data, len(self.train_data))
+
     def train__iter__(self):
+        return zip(
+            *[
+                self.get_stream(self.random_sample(i))
+                for i in range(self.batch_size)
+            ]
+        )
 
-        rng = create_unique_rng(self.model.current_epoch)
+    def get_stream(self, data):
+        return chain.from_iterable(map(self.process_data, cycle(data)))
 
-        while True:
-
-            file_dict, *_ = rng.choices(
-                self.train_data,
-                k=1,
-                weights=[file["duration"] for file in self.train_data],
-            )
-            file_duration = file_dict["duration"]
-            num_segments = self.get_num_segments(
-                file_duration, self.duration, self.stride
-            )
-            for index in range(0, num_segments):
-                start_time = index * self.stride
-                yield self.prepare_segment(file_dict, start_time)
+    def process_data(self, data):
+        for item in data:
+            yield self.prepare_segment(*item)
 
     @staticmethod
     def get_num_segments(file_duration, duration, stride):
 
@@ -264,6 +318,7 @@ class EnhancerDataset(TaskDataset):
 
     def prepare_segment(self, file_dict: dict, start_time: float):
 
+        print(file_dict["clean"].split("/")[-1], "->", start_time)
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -292,16 +347,7 @@ class EnhancerDataset(TaskDataset):
 
     def train__len__(self):
 
-        return math.ceil(
-            sum(
-                [
-                    self.get_num_segments(
-                        file["duration"], self.duration, self.stride
-                    )
-                    for file in self.train_data
-                ]
-            )
-        )
+        return sum([len(item) for item in self.train_data])
 
     def val__len__(self):
         return len(self._validation)
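
Note on the new sampling scheme: prepare_traindata pre-computes one
(metadata, start_time) list per file, train__iter__ builds batch_size
independently shuffled, endlessly cycled streams over those lists, and
zip(*streams) makes every yielded item a complete batch (hence
batch_size=None plus the custom collate_fn on the DataLoader). Below is a
minimal self-contained sketch of that pattern, with plain dicts standing
in for file metadata and a hypothetical create_unique_rng that is only
assumed to derive a deterministic random.Random from (epoch, stream index);
it is an illustration, not the repo's exact implementation.

import math
import random
from itertools import chain, cycle

def create_unique_rng(epoch, index):
    # assumption: deterministic per-(epoch, stream) RNG, like the repo helper
    return random.Random(hash((epoch, index)))

def get_num_segments(file_duration, duration, stride):
    # a file shorter than one window still contributes a single segment;
    # e.g. a 3.2 s file with duration=stride=1.0 gives ceil(2.2) + 1 = 4
    # windows starting at 0.0, 1.0, 2.0 and 3.0 (the last running past
    # the end of the file)
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1

def prepare_traindata(files, duration=1.0, stride=1.0):
    # one list of (file, start_time) pairs per file, as in the diff
    return [
        [
            (file, index * stride)
            for index in range(
                get_num_segments(file["duration"], duration, stride)
            )
        ]
        for file in files
    ]

def get_stream(shuffled):
    # flatten the per-file segment lists into one endless segment stream
    return chain.from_iterable(cycle(shuffled))

def train_iter(train_data, batch_size, epoch):
    # one independently shuffled stream per batch slot; zip emits full batches
    streams = [
        get_stream(
            create_unique_rng(epoch, i).sample(train_data, len(train_data))
        )
        for i in range(batch_size)
    ]
    return zip(*streams)

files = [{"clean": "a.wav", "duration": 3.2}, {"clean": "b.wav", "duration": 1.7}]
batches = train_iter(prepare_traindata(files), batch_size=4, epoch=0)
print(next(batches))  # tuple of 4 (file-metadata, start_time) pairs

With this layout, train__len__ reduces to the total number of
pre-computed segments, sum(len(item) for item in train_data), and
worker_init_fn shards the per-file lists across DataLoader workers so
that different workers do not emit duplicate batches.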