diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index d6ab415..a6b6ba1 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule): name: str, root_dir: str, files: Files, - valid_minutes: float = 0.20, + min_valid_minutes: float = 0.20, duration: float = 1.0, stride=None, sampling_rate: int = 48000, @@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule): if num_workers is None: num_workers = multiprocessing.cpu_count() // 2 self.num_workers = num_workers - if valid_minutes > 0.0: - self.valid_minutes = valid_minutes + if min_valid_minutes > 0.0: + self.min_valid_minutes = min_valid_minutes else: - raise ValueError("valid_minutes must be greater than 0") + raise ValueError("min_valid_minutes must be greater than 0") self.augmentations = augmentations @@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule): ) train_data = fp.prepare_matching_dict() train_data, self.val_data = self.train_valid_split( - train_data, valid_minutes=self.valid_minutes, random_state=42 + train_data, + min_valid_minutes=self.min_valid_minutes, + random_state=42, ) self.train_data = self.prepare_traindata(train_data) @@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule): self._test = self.prepare_mapstype(test_data) def train_valid_split( - self, data, valid_minutes: float = 20, random_state: int = 42 + self, data, min_valid_minutes: float = 20, random_state: int = 42 ): - valid_minutes *= 60 + min_valid_minutes *= 60 valid_sec_now = 0.0 valid_indices = [] all_speakers = np.unique( @@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule): possible_indices = list(range(0, len(all_speakers))) rng = create_unique_rng(len(all_speakers)) - while valid_sec_now <= valid_minutes: + while valid_sec_now <= min_valid_minutes: speaker_index = rng.choice(possible_indices) possible_indices.remove(speaker_index) speaker_name = all_speakers[speaker_index] @@ -257,10 +259,15 @@ class 
EnhancerDataset(TaskDataset): files : Files dataclass containing train_clean, train_noisy, test_clean, test_noisy folder names (refer enhancer.utils.Files dataclass) + min_valid_minutes: float + minimum duration of the validation split, in minutes + the algorithm randomly selects n speakers (total duration >= min_valid_minutes) from the train data to form the validation data. duration : float expected audio duration of single audio sample for training sampling_rate : int desired sampling rate + padding_mode: str + padding mode (silent,reflect) batch_size : int batch size of each batch num_workers : int @@ -271,6 +278,7 @@ class EnhancerDataset(TaskDataset): use one_to_many mapping for multiple noisy files for each clean file + """ def __init__( @@ -278,10 +286,11 @@ class EnhancerDataset(TaskDataset): name: str, root_dir: str, files: Files, - valid_minutes=5.0, + min_valid_minutes=5.0, duration=1.0, stride=None, sampling_rate=48000, + padding_mode: str = "silent", matching_function=None, batch_size=32, num_workers: Optional[int] = None, @@ -292,7 +301,7 @@ class EnhancerDataset(TaskDataset): name=name, root_dir=root_dir, files=files, - valid_minutes=valid_minutes, + min_valid_minutes=min_valid_minutes, sampling_rate=sampling_rate, duration=duration, matching_function=matching_function, @@ -306,6 +315,7 @@ class EnhancerDataset(TaskDataset): self.duration = max(1.0, duration) self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) self.stride = stride or duration + self.padding_mode = padding_mode def setup(self, stage: Optional[str] = None): @@ -344,6 +354,7 @@ class EnhancerDataset(TaskDataset): self.duration * self.sampling_rate - clean_segment.shape[-1] ), ), + mode=self.padding_mode, ) noisy_segment = F.pad( noisy_segment, @@ -353,6 +364,7 @@ class EnhancerDataset(TaskDataset): self.duration * self.sampling_rate - noisy_segment.shape[-1] ), ), + mode=self.padding_mode, ) return { "clean": clean_segment,