diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 95c73a1..d2b7526 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -5,6 +5,7 @@ from typing import Optional import pytorch_lightning as pl import torch.nn.functional as F +from sklearn.model_selection import train_test_split from torch.utils.data import DataLoader, Dataset, IterableDataset from enhancer.data.fileprocessor import Fileprocessor @@ -36,12 +37,24 @@ class ValidDataset(Dataset): return self.dataset.val__len__() +class TestDataset(Dataset): + def __init__(self, dataset): + self.dataset = dataset + + def __getitem__(self, idx): + return self.dataset.test__getitem__(idx) + + def __len__(self): + return self.dataset.test__len__() + + class TaskDataset(pl.LightningDataModule): def __init__( self, name: str, root_dir: str, files: Files, + valid_size: float = 0.20, duration: float = 1.0, sampling_rate: int = 48000, matching_function=None, @@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule): if num_workers is None: num_workers = multiprocessing.cpu_count() // 2 self.num_workers = num_workers + if valid_size > 0.0: + self.valid_size = valid_size + else: + raise ValueError("valid_size must be greater than 0") def setup(self, stage: Optional[str] = None): + """ + prepare train/validation/test data splits + """ if stage in ("fit", None): @@ -70,25 +90,33 @@ class TaskDataset(pl.LightningDataModule): fp = Fileprocessor.from_name( self.name, train_clean, train_noisy, self.matching_function ) - self.train_data = fp.prepare_matching_dict() - - val_clean = os.path.join(self.root_dir, self.files.test_clean) - val_noisy = os.path.join(self.root_dir, self.files.test_noisy) - fp = Fileprocessor.from_name( - self.name, val_clean, val_noisy, self.matching_function + train_data = fp.prepare_matching_dict() + self.train_data, self.val_data = train_test_split( + train_data, test_size=0.20, shuffle=True, random_state=42 ) - val_data = fp.prepare_matching_dict() - for item in val_data: - clean, noisy, total_dur = item.values() - if total_dur < self.duration: - continue - num_segments = round(total_dur / self.duration) - for index in range(num_segments): - start_time = index * self.duration - self._validation.append( - ({"clean": clean, "noisy": noisy}, start_time) - ) + self._validation = self.prepare_mapstype(self.val_data) + + test_clean = os.path.join(self.root_dir, self.files.test_clean) + test_noisy = os.path.join(self.root_dir, self.files.test_noisy) + fp = Fileprocessor.from_name( + self.name, test_clean, test_noisy, self.matching_function + ) + test_data = fp.prepare_matching_dict() + self._test = self.prepare_mapstype(test_data) + + def prepare_mapstype(self, data): + + metadata = [] + for item in data: + clean, noisy, total_dur = item.values() + if total_dur < self.duration: + continue + num_segments = round(total_dur / self.duration) + for index in range(num_segments): + start_time = index * self.duration + metadata.append(({"clean": clean, "noisy": noisy}, start_time)) + return metadata def train_dataloader(self): return DataLoader( @@ -104,6 +132,13 @@ class TaskDataset(pl.LightningDataModule): num_workers=self.num_workers, ) + def test_dataloader(self): + return DataLoader( + TestDataset(self), + batch_size=self.batch_size, + num_workers=self.num_workers, + ) + class EnhancerDataset(TaskDataset): """ @@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset): name: str, root_dir: str, files: Files, + valid_size=0.2, duration=1.0, sampling_rate=48000, matching_function=None, @@ -148,6 +184,7 @@ class EnhancerDataset(TaskDataset): name=name, root_dir=root_dir, files=files, + valid_size=valid_size, sampling_rate=sampling_rate, duration=duration, matching_function=matching_function, @@ -183,6 +220,9 @@ class EnhancerDataset(TaskDataset): def val__getitem__(self, idx): return self.prepare_segment(*self._validation[idx]) + def test__getitem__(self, idx): + return self.prepare_segment(*self._test[idx]) + def prepare_segment(self, file_dict: dict, start_time: float): clean_segment = self.audio( @@ -218,3 +258,6 @@ class EnhancerDataset(TaskDataset): def val__len__(self): return len(self._validation) + + def test__len__(self): + return len(self._test) diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index 03afc73..e718f15 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -55,7 +55,7 @@ class ProcessorFunctions: One clean audio have multiple noisy audio files """ - matching_wavfiles = dict() + matching_wavfiles = list() clean_filenames = [ file.split("/")[-1] for file in glob.glob(os.path.join(clean_path, "*.wav")) @@ -73,7 +73,7 @@ class ProcessorFunctions: if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( sr_clean == sr_noisy ): - matching_wavfiles.update( + matching_wavfiles.append( { "clean": os.path.join(clean_path, clean_file), "noisy": noisy_file,