specify valid size in mins
This commit is contained in:
parent
dab7e73d53
commit
e118c31f18
|
|
@ -5,7 +5,6 @@ from typing import Optional
|
|||
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
|
||||
from enhancer.data.fileprocessor import Fileprocessor
|
||||
|
|
@ -54,7 +53,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
name: str,
|
||||
root_dir: str,
|
||||
files: Files,
|
||||
valid_size: float = 0.20,
|
||||
valid_minutes: float = 0.20,
|
||||
duration: float = 1.0,
|
||||
sampling_rate: int = 48000,
|
||||
matching_function=None,
|
||||
|
|
@ -73,10 +72,10 @@ class TaskDataset(pl.LightningDataModule):
|
|||
if num_workers is None:
|
||||
num_workers = multiprocessing.cpu_count() // 2
|
||||
self.num_workers = num_workers
|
||||
if valid_size > 0.0:
|
||||
self.valid_size = valid_size
|
||||
if valid_minutes > 0.0:
|
||||
self.valid_minutes = valid_minutes
|
||||
else:
|
||||
raise ValueError("valid_size must be greater than 0")
|
||||
raise ValueError("valid_minutes must be greater than 0")
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
"""
|
||||
|
|
@ -91,8 +90,8 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self.name, train_clean, train_noisy, self.matching_function
|
||||
)
|
||||
train_data = fp.prepare_matching_dict()
|
||||
self.train_data, self.val_data = train_test_split(
|
||||
train_data, test_size=0.20, shuffle=True, random_state=42
|
||||
self.train_data, self.val_data = self.train_valid_split(
|
||||
train_data, valid_minutes=self.valid_minutes, random_state=42
|
||||
)
|
||||
|
||||
self._validation = self.prepare_mapstype(self.val_data)
|
||||
|
|
@ -105,6 +104,28 @@ class TaskDataset(pl.LightningDataModule):
|
|||
test_data = fp.prepare_matching_dict()
|
||||
self._test = self.prepare_mapstype(test_data)
|
||||
|
||||
def train_valid_split(
|
||||
self, data, valid_minutes: float = 20, random_state: int = 42
|
||||
):
|
||||
|
||||
valid_minutes *= 60
|
||||
valid_min_now = 0.0
|
||||
valid_indices = []
|
||||
random_indices = list(range(0, len(data)))
|
||||
rng = create_unique_rng(random_state)
|
||||
rng.shuffle(random_indices)
|
||||
i = 0
|
||||
while valid_min_now <= valid_minutes:
|
||||
valid_indices.append(random_indices[i])
|
||||
valid_min_now += data[random_indices[i]]["duration"]
|
||||
i += 1
|
||||
|
||||
train_data = [
|
||||
item for i, item in enumerate(data) if i not in valid_indices
|
||||
]
|
||||
valid_data = [item for i, item in enumerate(data) if i in valid_indices]
|
||||
return train_data, valid_data
|
||||
|
||||
def prepare_mapstype(self, data):
|
||||
|
||||
metadata = []
|
||||
|
|
@ -172,7 +193,7 @@ class EnhancerDataset(TaskDataset):
|
|||
name: str,
|
||||
root_dir: str,
|
||||
files: Files,
|
||||
valid_size=0.2,
|
||||
valid_minutes=5.0,
|
||||
duration=1.0,
|
||||
sampling_rate=48000,
|
||||
matching_function=None,
|
||||
|
|
@ -184,7 +205,7 @@ class EnhancerDataset(TaskDataset):
|
|||
name=name,
|
||||
root_dir=root_dir,
|
||||
files=files,
|
||||
valid_size=valid_size,
|
||||
valid_minutes=valid_minutes,
|
||||
sampling_rate=sampling_rate,
|
||||
duration=duration,
|
||||
matching_function=matching_function,
|
||||
|
|
|
|||
Loading…
Reference in New Issue