add test dataloader

This commit is contained in:
shahules786 2022-10-10 12:45:23 +05:30
parent 37fe86063d
commit 3e654d10a7
2 changed files with 62 additions and 19 deletions

View File

@ -5,6 +5,7 @@ from typing import Optional
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.nn.functional as F import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, IterableDataset from torch.utils.data import DataLoader, Dataset, IterableDataset
from enhancer.data.fileprocessor import Fileprocessor from enhancer.data.fileprocessor import Fileprocessor
@ -36,12 +37,24 @@ class ValidDataset(Dataset):
return self.dataset.val__len__() return self.dataset.val__len__()
class TestDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, idx):
return self.dataset.test__getitem__(idx)
def __len__(self):
return self.dataset.test__len__()
class TaskDataset(pl.LightningDataModule): class TaskDataset(pl.LightningDataModule):
def __init__( def __init__(
self, self,
name: str, name: str,
root_dir: str, root_dir: str,
files: Files, files: Files,
valid_size: float = 0.20,
duration: float = 1.0, duration: float = 1.0,
sampling_rate: int = 48000, sampling_rate: int = 48000,
matching_function=None, matching_function=None,
@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule):
if num_workers is None: if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2 num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers self.num_workers = num_workers
if valid_size > 0.0:
self.valid_size = valid_size
else:
raise ValueError("valid_size must be greater than 0")
def setup(self, stage: Optional[str] = None): def setup(self, stage: Optional[str] = None):
"""
prepare train/validation/test data splits
"""
if stage in ("fit", None): if stage in ("fit", None):
@ -70,25 +90,33 @@ class TaskDataset(pl.LightningDataModule):
fp = Fileprocessor.from_name( fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function self.name, train_clean, train_noisy, self.matching_function
) )
self.train_data = fp.prepare_matching_dict() train_data = fp.prepare_matching_dict()
self.train_data, self.val_data = train_test_split(
val_clean = os.path.join(self.root_dir, self.files.test_clean) train_data, test_size=0.20, shuffle=True, random_state=42
val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, val_clean, val_noisy, self.matching_function
) )
val_data = fp.prepare_matching_dict()
for item in val_data: self._validation = self.prepare_mapstype(self.val_data)
test_clean = os.path.join(self.root_dir, self.files.test_clean)
test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, test_clean, test_noisy, self.matching_function
)
test_data = fp.prepare_matching_dict()
self._test = self.prepare_mapstype(test_data)
def prepare_mapstype(self, data):
metadata = []
for item in data:
clean, noisy, total_dur = item.values() clean, noisy, total_dur = item.values()
if total_dur < self.duration: if total_dur < self.duration:
continue continue
num_segments = round(total_dur / self.duration) num_segments = round(total_dur / self.duration)
for index in range(num_segments): for index in range(num_segments):
start_time = index * self.duration start_time = index * self.duration
self._validation.append( metadata.append(({"clean": clean, "noisy": noisy}, start_time))
({"clean": clean, "noisy": noisy}, start_time) return metadata
)
def train_dataloader(self): def train_dataloader(self):
return DataLoader( return DataLoader(
@ -104,6 +132,13 @@ class TaskDataset(pl.LightningDataModule):
num_workers=self.num_workers, num_workers=self.num_workers,
) )
def test_dataloader(self):
return DataLoader(
TestDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
class EnhancerDataset(TaskDataset): class EnhancerDataset(TaskDataset):
""" """
@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset):
name: str, name: str,
root_dir: str, root_dir: str,
files: Files, files: Files,
valid_size=0.2,
duration=1.0, duration=1.0,
sampling_rate=48000, sampling_rate=48000,
matching_function=None, matching_function=None,
@ -148,6 +184,7 @@ class EnhancerDataset(TaskDataset):
name=name, name=name,
root_dir=root_dir, root_dir=root_dir,
files=files, files=files,
valid_size=valid_size,
sampling_rate=sampling_rate, sampling_rate=sampling_rate,
duration=duration, duration=duration,
matching_function=matching_function, matching_function=matching_function,
@ -183,6 +220,9 @@ class EnhancerDataset(TaskDataset):
def val__getitem__(self, idx): def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx]) return self.prepare_segment(*self._validation[idx])
def test__getitem__(self, idx):
return self.prepare_segment(*self._test[idx])
def prepare_segment(self, file_dict: dict, start_time: float): def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio( clean_segment = self.audio(
@ -218,3 +258,6 @@ class EnhancerDataset(TaskDataset):
def val__len__(self): def val__len__(self):
return len(self._validation) return len(self._validation)
def test__len__(self):
return len(self._test)

View File

@ -55,7 +55,7 @@ class ProcessorFunctions:
One clean audio have multiple noisy audio files One clean audio have multiple noisy audio files
""" """
matching_wavfiles = dict() matching_wavfiles = list()
clean_filenames = [ clean_filenames = [
file.split("/")[-1] file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav")) for file in glob.glob(os.path.join(clean_path, "*.wav"))
@ -73,7 +73,7 @@ class ProcessorFunctions:
if (clean_file.shape[-1] == noisy_file.shape[-1]) and ( if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy sr_clean == sr_noisy
): ):
matching_wavfiles.update( matching_wavfiles.append(
{ {
"clean": os.path.join(clean_path, clean_file), "clean": os.path.join(clean_path, clean_file),
"noisy": noisy_file, "noisy": noisy_file,