diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index d2b7526..02a1d3b 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -5,7 +5,6 @@ from typing import Optional import pytorch_lightning as pl import torch.nn.functional as F -from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader, Dataset, IterableDataset from enhancer.data.fileprocessor import Fileprocessor @@ -54,7 +53,7 @@ class TaskDataset(pl.LightningDataModule): name: str, root_dir: str, files: Files, - valid_size: float = 0.20, + valid_minutes: float = 0.20, duration: float = 1.0, sampling_rate: int = 48000, matching_function=None, @@ -73,10 +72,10 @@ class TaskDataset(pl.LightningDataModule): if num_workers is None: num_workers = multiprocessing.cpu_count() // 2 self.num_workers = num_workers - if valid_size > 0.0: - self.valid_size = valid_size + if valid_minutes > 0.0: + self.valid_minutes = valid_minutes else: - raise ValueError("valid_size must be greater than 0") + raise ValueError("valid_minutes must be greater than 0") def setup(self, stage: Optional[str] = None): """ @@ -91,8 +90,8 @@ class TaskDataset(pl.LightningDataModule): self.name, train_clean, train_noisy, self.matching_function ) train_data = fp.prepare_matching_dict() - self.train_data, self.val_data = train_test_split( - train_data, test_size=0.20, shuffle=True, random_state=42 + self.train_data, self.val_data = self.train_valid_split( + train_data, valid_minutes=self.valid_minutes, random_state=42 ) self._validation = self.prepare_mapstype(self.val_data) @@ -105,6 +104,28 @@ class TaskDataset(pl.LightningDataModule): test_data = fp.prepare_matching_dict() self._test = self.prepare_mapstype(test_data) + def train_valid_split( + self, data, valid_minutes: float = 20, random_state: int = 42 + ): + + valid_minutes *= 60 + valid_min_now = 0.0 + valid_indices = [] + random_indices = list(range(0, len(data))) + rng = create_unique_rng(random_state) + rng.shuffle(random_indices) + i = 0 + while valid_min_now <= valid_minutes: + valid_indices.append(random_indices[i]) + valid_min_now += data[random_indices[i]]["duration"] + i += 1 + + train_data = [ + item for i, item in enumerate(data) if i not in valid_indices + ] + valid_data = [item for i, item in enumerate(data) if i in valid_indices] + return train_data, valid_data + def prepare_mapstype(self, data): metadata = [] @@ -172,7 +193,7 @@ class EnhancerDataset(TaskDataset): name: str, root_dir: str, files: Files, - valid_size=0.2, + valid_minutes=5.0, duration=1.0, sampling_rate=48000, matching_function=None, @@ -184,7 +205,7 @@ class EnhancerDataset(TaskDataset): name=name, root_dir=root_dir, files=files, - valid_size=valid_size, + valid_minutes=valid_minutes, sampling_rate=sampling_rate, duration=duration, matching_function=matching_function,