From 5dc5fd8f901e916aabcaa16ba173a353a5cd582e Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 24 Oct 2022 21:15:25 +0530 Subject: [PATCH] default stride None --- enhancer/data/dataset.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 08b402f..ab2f1ce 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -14,6 +14,9 @@ from enhancer.utils.config import Files from enhancer.utils.io import Audio from enhancer.utils.random import create_unique_rng +# from torch_audiomentations import Compose + + LARGE_NUM = 2147483647 @@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule): matching_function=None, batch_size=32, num_workers: Optional[int] = None, + # augmentations: Optional[Compose] = None, ): super().__init__() @@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule): else: raise ValueError("valid_minutes must be greater than 0") + # self.augmentations = augmentations + def setup(self, stage: Optional[str] = None): """ prepare train/validation/test data splits @@ -161,7 +167,9 @@ class TaskDataset(pl.LightningDataModule): if total_dur < self.duration: metadata.append(({"clean": clean, "noisy": noisy}, 0.0)) else: - num_segments = round(total_dur / self.duration) + num_segments = self.get_num_segments( + total_dur, self.duration, self.duration + ) for index in range(num_segments): start_time = index * self.duration metadata.append( @@ -175,8 +183,11 @@ class TaskDataset(pl.LightningDataModule): @property def generator(self): generator = torch.Generator() - generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM) - return generator + if hasattr(self, "model"): + seed = self.model.current_epoch + LARGE_NUM + else: + seed = LARGE_NUM + return generator.manual_seed(seed) def train_dataloader(self): return DataLoader( @@ -235,11 +246,12 @@ class EnhancerDataset(TaskDataset): files: Files, valid_minutes=5.0, duration=1.0, - stride=0.5, + stride=None, sampling_rate=48000, matching_function=None, batch_size=32, num_workers: Optional[int] = None, + # augmentations: Optional[Compose] = None, ): super().__init__( @@ -252,6 +264,7 @@ class EnhancerDataset(TaskDataset): matching_function=matching_function, batch_size=batch_size, num_workers=num_workers, + # augmentations=augmentations, ) self.sampling_rate = sampling_rate