diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index e2833b9..d6ab415 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -8,7 +8,7 @@ import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader, Dataset, RandomSampler from torch_audiomentations import Compose from enhancer.data.fileprocessor import Fileprocessor @@ -221,11 +221,13 @@ class TaskDataset(pl.LightningDataModule): return generator.manual_seed(seed) def train_dataloader(self): + dataset = TrainDataset(self) + sampler = RandomSampler(dataset, generator=self.generator) return DataLoader( - TrainDataset(self), + dataset, batch_size=self.batch_size, num_workers=self.num_workers, - generator=self.generator, + sampler=sampler, collate_fn=self.train_collatefn, ) @@ -328,7 +330,6 @@ class EnhancerDataset(TaskDataset): return self.prepare_segment(*self._test[idx]) def prepare_segment(self, file_dict: dict, start_time: float): - print(file_dict["clean"].split("/")[-1]) clean_segment = self.audio( file_dict["clean"], offset=start_time, duration=self.duration )