diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 7f7ae67..08b402f 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,14 +1,12 @@
 import math
 import multiprocessing
 import os
-import random
-from itertools import chain, cycle
 from typing import Optional
 
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset
 
 from enhancer.data.fileprocessor import Fileprocessor
 from enhancer.utils import check_files
@@ -16,13 +14,15 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng
 
+LARGE_NUM = 2147483647
 
-class TrainDataset(IterableDataset):
+
+class TrainDataset(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset
 
-    def __iter__(self):
-        return self.dataset.train__iter__()
+    def __getitem__(self, idx):
+        return self.dataset.train__getitem__(idx)
 
     def __len__(self):
         return self.dataset.train__len__()
@@ -135,16 +135,11 @@
     def prepare_traindata(self, data):
         train_data = []
         for item in data:
-            samples_metadata = []
             clean, noisy, total_dur = item.values()
             num_segments = self.get_num_segments(
                 total_dur, self.duration, self.stride
             )
-            for index in range(num_segments):
-                start = index * self.stride
-                samples_metadata.append(
-                    ({"clean": clean, "noisy": noisy}, start)
-                )
+            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
             train_data.append(samples_metadata)
         return train_data
 
@@ -175,31 +170,20 @@
         return metadata
 
     def train_collatefn(self, batch):
-        output = {"noisy": [], "clean": []}
-        for item in batch:
-            output["noisy"].append(item["noisy"])
-            output["clean"].append(item["clean"])
+        raise NotImplementedError("Not implemented")
 
-        output["clean"] = torch.stack(output["clean"], dim=0)
-        output["noisy"] = torch.stack(output["noisy"], dim=0)
-        return output
-
-    def worker_init_fn(self, _):
-        worker_info = torch.utils.data.get_worker_info()
-        dataset = worker_info.dataset
-        worker_id = worker_info.id
-        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
-        dataset.data = dataset.dataset.train_data[
-            worker_id * split_size : (worker_id + 1) * split_size
-        ]
+    @property
+    def generator(self):
+        generator = torch.Generator()
+        generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
+        return generator
 
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=None,
+            batch_size=self.batch_size,
             num_workers=self.num_workers,
-            collate_fn=self.train_collatefn,
-            worker_init_fn=self.worker_init_fn,
+            generator=self.generator,
         )
 
     def val_dataloader(self):
@@ -280,35 +264,16 @@
 
         super().setup(stage=stage)
 
-    def random_sample(self, train_data):
-        return random.sample(train_data, len(train_data))
+    def train__getitem__(self, idx):
 
-    def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-        train_data = rng.sample(self.train_data, len(self.train_data))
-        return zip(
-            *[
-                self.get_stream(self.random_sample(train_data))
-                for i in range(self.batch_size)
-            ]
-        )
-
-    def get_stream(self, data):
-        return chain.from_iterable(map(self.process_data, cycle(data)))
-
-    def process_data(self, data):
-        for item in data:
-            yield self.prepare_segment(*item)
-
-    @staticmethod
-    def get_num_segments(file_duration, duration, stride):
-
-        if file_duration < duration:
-            num_segments = 1
-        else:
-            num_segments = math.ceil((file_duration - duration) / stride) + 1
-
-        return num_segments
+        for filedict, num_samples in self.train_data:
+            if idx >= num_samples:
+                idx -= num_samples
+                continue
+            start = 0
+            if self.duration is not None:
+                start = idx * self.stride
+            return self.prepare_segment(filedict, start)
 
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
@@ -348,7 +313,8 @@
         }
 
     def train__len__(self):
-        return sum([len(item) for item in self.train_data]) // (self.batch_size)
+        _, num_examples = list(zip(*self.train_data))
+        return sum(num_examples)
 
     def val__len__(self):
         return len(self._validation)
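
The map-style train__getitem__ above recovers a (file, start offset) pair from a flat index by walking the per-file segment counts that prepare_traindata now stores. A minimal standalone sketch of that bookkeeping, assuming the removed get_num_segments formula, where a 10 s file cut into 2 s windows at a 1 s stride yields ceil((10 - 2) / 1) + 1 = 9 segments; file names and durations are invented for illustration:

    import math

    def get_num_segments(file_duration, duration, stride):
        # Short files still contribute one segment; otherwise a window of
        # `duration` seconds starts every `stride` seconds.
        if file_duration < duration:
            return 1
        return math.ceil((file_duration - duration) / stride) + 1

    duration, stride = 2.0, 1.0
    # (filedict, num_segments) pairs, shaped like prepare_traindata's output.
    train_data = [
        ({"clean": "a_clean.wav", "noisy": "a_noisy.wav"},
         get_num_segments(10.0, duration, stride)),  # 9 segments
        ({"clean": "b_clean.wav", "noisy": "b_noisy.wav"},
         get_num_segments(1.5, duration, stride)),   # 1 segment
    ]

    def locate(idx):
        # Same walk as train__getitem__: skip whole files until idx lands
        # inside one, then convert the remainder into a start offset.
        for filedict, num_segments in train_data:
            if idx >= num_segments:
                idx -= num_segments
                continue
            return filedict, idx * stride
        raise IndexError(idx)

    print(locate(8))  # last window of file a: (..., 8.0)
    print(locate(9))  # only (and first) window of file b: (..., 0.0)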
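Because the DataLoader now receives a real batch_size, PyTorch's default collate takes over the stacking that the gutted train_collatefn used to do: given a list of dicts of equal-shaped tensors, it stacks them key by key. A quick sketch, assuming prepare_segment returns such a dict of clean/noisy tensors; the sample shapes here are made up:

    import torch
    from torch.utils.data.dataloader import default_collate

    # Four fake samples shaped like prepare_segment's per-item output.
    samples = [
        {"clean": torch.zeros(1, 16000), "noisy": torch.zeros(1, 16000)}
        for _ in range(4)
    ]

    batch = default_collate(samples)
    print(batch["clean"].shape)  # torch.Size([4, 1, 16000])
    print(batch["noisy"].shape)  # torch.Size([4, 1, 16000])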
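Finally, the generator property replaces both create_unique_rng and the worker-splitting worker_init_fn: seeding a fresh torch.Generator from current_epoch + LARGE_NUM makes each epoch's random state reproducible yet distinct. A sketch of that behavior, where epoch_generator is a hypothetical stand-in for the property; note that the DataLoader's generator argument only drives sample order when shuffling is enabled, e.g. via shuffle=True:

    import torch

    LARGE_NUM = 2147483647

    def epoch_generator(current_epoch):
        # manual_seed returns the generator itself, mirroring the property.
        return torch.Generator().manual_seed(current_epoch + LARGE_NUM)

    a = torch.randperm(8, generator=epoch_generator(0))
    b = torch.randperm(8, generator=epoch_generator(0))
    c = torch.randperm(8, generator=epoch_generator(1))

    print(torch.equal(a, b))  # True: same epoch, same permutation
    print(torch.equal(a, c))  # almost surely False: new epoch, new shuffle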