From 6f1acf0423525f3e66f3b648e948aafffcf5f869 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 29 Oct 2022 10:33:59 +0530 Subject: [PATCH] Revert "add random sampler" This reverts commit aa52d1ed93293d0f22f47c5447b576197e770aa5. --- enhancer/data/dataset.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index a6b6ba1..e370526 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -8,7 +8,7 @@ import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F -from torch.utils.data import DataLoader, Dataset, RandomSampler +from torch.utils.data import DataLoader, Dataset from torch_audiomentations import Compose from enhancer.data.fileprocessor import Fileprocessor @@ -135,7 +135,6 @@ class TaskDataset(pl.LightningDataModule): speaker_index = rng.choice(possible_indices) possible_indices.remove(speaker_index) speaker_name = all_speakers[speaker_index] - print(f"Selected f{speaker_name} for valid") file_indices = [ i for i, file in enumerate(data) @@ -223,13 +222,11 @@ class TaskDataset(pl.LightningDataModule): return generator.manual_seed(seed) def train_dataloader(self): - dataset = TrainDataset(self) - sampler = RandomSampler(dataset, generator=self.generator) return DataLoader( - dataset, + TrainDataset(self), batch_size=self.batch_size, num_workers=self.num_workers, - sampler=sampler, + generator=self.generator, collate_fn=self.train_collatefn, ) @@ -340,6 +337,7 @@ class EnhancerDataset(TaskDataset): return self.prepare_segment(*self._test[idx]) def prepare_segment(self, file_dict: dict, start_time: float): + clean_segment = self.audio( file_dict["clean"], offset=start_time, duration=self.duration )