specify valid size in mins
This commit is contained in:
parent
dab7e73d53
commit
e118c31f18
|
|
@ -5,7 +5,6 @@ from typing import Optional
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||||
|
|
||||||
from enhancer.data.fileprocessor import Fileprocessor
|
from enhancer.data.fileprocessor import Fileprocessor
|
||||||
|
|
@ -54,7 +53,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
name: str,
|
name: str,
|
||||||
root_dir: str,
|
root_dir: str,
|
||||||
files: Files,
|
files: Files,
|
||||||
valid_size: float = 0.20,
|
valid_minutes: float = 0.20,
|
||||||
duration: float = 1.0,
|
duration: float = 1.0,
|
||||||
sampling_rate: int = 48000,
|
sampling_rate: int = 48000,
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
|
|
@ -73,10 +72,10 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
if num_workers is None:
|
if num_workers is None:
|
||||||
num_workers = multiprocessing.cpu_count() // 2
|
num_workers = multiprocessing.cpu_count() // 2
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
if valid_size > 0.0:
|
if valid_minutes > 0.0:
|
||||||
self.valid_size = valid_size
|
self.valid_minutes = valid_minutes
|
||||||
else:
|
else:
|
||||||
raise ValueError("valid_size must be greater than 0")
|
raise ValueError("valid_minutes must be greater than 0")
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
|
|
@ -91,8 +90,8 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
self.name, train_clean, train_noisy, self.matching_function
|
self.name, train_clean, train_noisy, self.matching_function
|
||||||
)
|
)
|
||||||
train_data = fp.prepare_matching_dict()
|
train_data = fp.prepare_matching_dict()
|
||||||
self.train_data, self.val_data = train_test_split(
|
self.train_data, self.val_data = self.train_valid_split(
|
||||||
train_data, test_size=0.20, shuffle=True, random_state=42
|
train_data, valid_minutes=self.valid_minutes, random_state=42
|
||||||
)
|
)
|
||||||
|
|
||||||
self._validation = self.prepare_mapstype(self.val_data)
|
self._validation = self.prepare_mapstype(self.val_data)
|
||||||
|
|
@ -105,6 +104,28 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
test_data = fp.prepare_matching_dict()
|
test_data = fp.prepare_matching_dict()
|
||||||
self._test = self.prepare_mapstype(test_data)
|
self._test = self.prepare_mapstype(test_data)
|
||||||
|
|
||||||
|
def train_valid_split(
|
||||||
|
self, data, valid_minutes: float = 20, random_state: int = 42
|
||||||
|
):
|
||||||
|
|
||||||
|
valid_minutes *= 60
|
||||||
|
valid_min_now = 0.0
|
||||||
|
valid_indices = []
|
||||||
|
random_indices = list(range(0, len(data)))
|
||||||
|
rng = create_unique_rng(random_state)
|
||||||
|
rng.shuffle(random_indices)
|
||||||
|
i = 0
|
||||||
|
while valid_min_now <= valid_minutes:
|
||||||
|
valid_indices.append(random_indices[i])
|
||||||
|
valid_min_now += data[random_indices[i]]["duration"]
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
train_data = [
|
||||||
|
item for i, item in enumerate(data) if i not in valid_indices
|
||||||
|
]
|
||||||
|
valid_data = [item for i, item in enumerate(data) if i in valid_indices]
|
||||||
|
return train_data, valid_data
|
||||||
|
|
||||||
def prepare_mapstype(self, data):
|
def prepare_mapstype(self, data):
|
||||||
|
|
||||||
metadata = []
|
metadata = []
|
||||||
|
|
@ -172,7 +193,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
name: str,
|
name: str,
|
||||||
root_dir: str,
|
root_dir: str,
|
||||||
files: Files,
|
files: Files,
|
||||||
valid_size=0.2,
|
valid_minutes=5.0,
|
||||||
duration=1.0,
|
duration=1.0,
|
||||||
sampling_rate=48000,
|
sampling_rate=48000,
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
|
|
@ -184,7 +205,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
name=name,
|
name=name,
|
||||||
root_dir=root_dir,
|
root_dir=root_dir,
|
||||||
files=files,
|
files=files,
|
||||||
valid_size=valid_size,
|
valid_minutes=valid_minutes,
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
matching_function=matching_function,
|
matching_function=matching_function,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue