From aa52d1ed93293d0f22f47c5447b576197e770aa5 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Fri, 28 Oct 2022 13:06:49 +0530 Subject: [PATCH] add random sampler --- enhancer/data/dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 8851ea2..d6ab415 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -8,7 +8,7 @@ import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader, Dataset, RandomSampler from torch_audiomentations import Compose from enhancer.data.fileprocessor import Fileprocessor @@ -133,6 +133,7 @@ class TaskDataset(pl.LightningDataModule): speaker_index = rng.choice(possible_indices) possible_indices.remove(speaker_index) speaker_name = all_speakers[speaker_index] + print(f"Selected f{speaker_name} for valid") file_indices = [ i for i, file in enumerate(data) @@ -220,11 +221,13 @@ class TaskDataset(pl.LightningDataModule): return generator.manual_seed(seed) def train_dataloader(self): + dataset = TrainDataset(self) + sampler = RandomSampler(dataset, generator=self.generator) return DataLoader( - TrainDataset(self), + dataset, batch_size=self.batch_size, num_workers=self.num_workers, - generator=self.generator, + sampler=sampler, collate_fn=self.train_collatefn, ) @@ -327,7 +330,6 @@ class EnhancerDataset(TaskDataset): return self.prepare_segment(*self._test[idx]) def prepare_segment(self, file_dict: dict, start_time: float): - clean_segment = self.audio( file_dict["clean"], offset=start_time, duration=self.duration )