Patch for enhancer/data/dataset.py. Text above the first "diff --git" line is
commentary for reviewers; patch tools skip it.

Changes:
  * train_dataloader builds TrainDataset(self) once and passes DataLoader an
    explicit RandomSampler(dataset, generator=self.generator) through
    sampler=, replacing the former generator= argument.
  * EnhancerDataset drops padding_mode: the docstring entry, the signature
    parameter, the self.padding_mode attribute and the mode= keyword argument
    of the call that precedes the returned dict.
  * TaskDataset prints the speaker name picked via rng.choice.
  * The blank line directly after the prepare_segment signature is removed.

NOTE(review): this patch was rebuilt from a whitespace-collapsed copy. Line
breaks follow the hunk headers' line counts, but the indentation of every
line is a reconstruction -- verify against the target file, or apply with
`git apply --ignore-whitespace`.
NOTE(review): the added print line originally read
`f"Selected f{speaker_name} for valid"`; the stray `f` inside the string has
been removed. The "index" line is kept as recorded, so its post-image hash no
longer matches the corrected content.

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index e370526..ada1ff3 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -8,7 +8,7 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch_audiomentations import Compose
 
 from enhancer.data.fileprocessor import Fileprocessor
@@ -135,6 +135,7 @@ class TaskDataset(pl.LightningDataModule):
             speaker_index = rng.choice(possible_indices)
             possible_indices.remove(speaker_index)
             speaker_name = all_speakers[speaker_index]
+            print(f"Selected {speaker_name} for valid")
             file_indices = [
                 i
                 for i, file in enumerate(data)
@@ -222,11 +223,13 @@ class TaskDataset(pl.LightningDataModule):
         return generator.manual_seed(seed)
 
     def train_dataloader(self):
+        dataset = TrainDataset(self)
+        sampler = RandomSampler(dataset, generator=self.generator)
         return DataLoader(
-            TrainDataset(self),
+            dataset,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            generator=self.generator,
+            sampler=sampler,
             collate_fn=self.train_collatefn,
         )
 
@@ -263,8 +266,6 @@ class EnhancerDataset(TaskDataset):
         expected audio duration of single audio sample for training
     sampling_rate : int
         desired sampling rate
-    padding_mode: str
-        padding mode (silent,reflect)
     batch_size : int
         batch size of each batch
     num_workers : int
@@ -287,7 +288,6 @@ class EnhancerDataset(TaskDataset):
         duration=1.0,
         stride=None,
         sampling_rate=48000,
-        padding_mode: str = "silent",
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
@@ -312,7 +312,6 @@ class EnhancerDataset(TaskDataset):
         self.duration = max(1.0, duration)
         self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
         self.stride = stride or duration
-        self.padding_mode = padding_mode
 
     def setup(self, stage: Optional[str] = None):
 
@@ -337,7 +336,6 @@ class EnhancerDataset(TaskDataset):
         return self.prepare_segment(*self._test[idx])
 
     def prepare_segment(self, file_dict: dict, start_time: float):
-
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -362,7 +360,6 @@ class EnhancerDataset(TaskDataset):
                     self.duration * self.sampling_rate - noisy_segment.shape[-1]
                 ),
             ),
-            mode=self.padding_mode,
         )
         return {
             "clean": clean_segment,