Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

shahules786 2022-10-17 13:10:54 +05:30
commit 77e5a14908
1 changed file with 30 additions and 9 deletions


@@ -5,7 +5,6 @@ from typing import Optional
 import pytorch_lightning as pl
 import torch.nn.functional as F
-from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader, Dataset, IterableDataset

 from enhancer.data.fileprocessor import Fileprocessor
@@ -54,7 +53,7 @@ class TaskDataset(pl.LightningDataModule):
         name: str,
         root_dir: str,
         files: Files,
-        valid_size: float = 0.20,
+        valid_minutes: float = 0.20,
         duration: float = 1.0,
         sampling_rate: int = 48000,
         matching_function=None,
@@ -73,10 +72,10 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
-        if valid_size > 0.0:
-            self.valid_size = valid_size
+        if valid_minutes > 0.0:
+            self.valid_minutes = valid_minutes
         else:
-            raise ValueError("valid_size must be greater than 0")
+            raise ValueError("valid_minutes must be greater than 0")

     def setup(self, stage: Optional[str] = None):
         """
@@ -91,8 +90,8 @@ class TaskDataset(pl.LightningDataModule):
                 self.name, train_clean, train_noisy, self.matching_function
             )
             train_data = fp.prepare_matching_dict()
-            self.train_data, self.val_data = train_test_split(
-                train_data, test_size=0.20, shuffle=True, random_state=42
+            self.train_data, self.val_data = self.train_valid_split(
+                train_data, valid_minutes=self.valid_minutes, random_state=42
             )

             self._validation = self.prepare_mapstype(self.val_data)
@@ -105,6 +104,28 @@ class TaskDataset(pl.LightningDataModule):
             test_data = fp.prepare_matching_dict()
             self._test = self.prepare_mapstype(test_data)

+    def train_valid_split(
+        self, data, valid_minutes: float = 20, random_state: int = 42
+    ):
+
+        valid_minutes *= 60
+        valid_min_now = 0.0
+        valid_indices = []
+        random_indices = list(range(0, len(data)))
+        rng = create_unique_rng(random_state)
+        rng.shuffle(random_indices)
+        i = 0
+        while valid_min_now <= valid_minutes:
+            valid_indices.append(random_indices[i])
+            valid_min_now += data[random_indices[i]]["duration"]
+            i += 1
+
+        train_data = [
+            item for i, item in enumerate(data) if i not in valid_indices
+        ]
+        valid_data = [item for i, item in enumerate(data) if i in valid_indices]
+        return train_data, valid_data
+
     def prepare_mapstype(self, data):
         metadata = []
@@ -172,7 +193,7 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
-        valid_size=0.2,
+        valid_minutes=5.0,
         duration=1.0,
         sampling_rate=48000,
         matching_function=None,
@@ -184,7 +205,7 @@ class EnhancerDataset(TaskDataset):
             name=name,
             root_dir=root_dir,
             files=files,
-            valid_size=valid_size,
+            valid_minutes=valid_minutes,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,