add test dataloader
parent 37fe86063d
commit 3e654d10a7
@@ -5,6 +5,7 @@ from typing import Optional

 import pytorch_lightning as pl
 import torch.nn.functional as F
+from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader, Dataset, IterableDataset

 from enhancer.data.fileprocessor import Fileprocessor
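The only new import is sklearn's train_test_split, which setup() now uses to carve the matched training pairs into train and validation subsets. A quick illustrative check of its behaviour on a plain Python list (the five placeholder strings below stand in for matched file records):

    from sklearn.model_selection import train_test_split

    pairs = ["p1", "p2", "p3", "p4", "p5"]  # placeholders for matched clean/noisy records
    train, val = train_test_split(pairs, test_size=0.20, shuffle=True, random_state=42)
    print(len(train), len(val))  # 4 1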
@@ -36,12 +37,24 @@ class ValidDataset(Dataset):
         return self.dataset.val__len__()


+class TestDataset(Dataset):
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __getitem__(self, idx):
+        return self.dataset.test__getitem__(idx)
+
+    def __len__(self):
+        return self.dataset.test__len__()
+
+
 class TaskDataset(pl.LightningDataModule):
     def __init__(
         self,
         name: str,
         root_dir: str,
         files: Files,
+        valid_size: float = 0.20,
         duration: float = 1.0,
         sampling_rate: int = 48000,
         matching_function=None,
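The new TestDataset mirrors the existing ValidDataset wrapper: a thin map-style view over the LightningDataModule, so indexing it simply forwards to the module's test__getitem__ and test__len__ hooks. Roughly, assuming a TaskDataset instance dm whose setup() has already run:

    test_view = TestDataset(dm)
    len(test_view)   # -> dm.test__len__(), i.e. len(dm._test)
    test_view[0]     # -> dm.test__getitem__(0), i.e. dm.prepare_segment(*dm._test[0])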
@@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        if valid_size > 0.0:
+            self.valid_size = valid_size
+        else:
+            raise ValueError("valid_size must be greater than 0")

     def setup(self, stage: Optional[str] = None):
+        """
+        prepare train/validation/test data splits
+        """

         if stage in ("fit", None):

@@ -70,25 +90,33 @@ class TaskDataset(pl.LightningDataModule):
             fp = Fileprocessor.from_name(
                 self.name, train_clean, train_noisy, self.matching_function
             )
-            self.train_data = fp.prepare_matching_dict()
-
-            val_clean = os.path.join(self.root_dir, self.files.test_clean)
-            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
-            fp = Fileprocessor.from_name(
-                self.name, val_clean, val_noisy, self.matching_function
-            )
-            val_data = fp.prepare_matching_dict()
-
-            for item in val_data:
+            train_data = fp.prepare_matching_dict()
+            self.train_data, self.val_data = train_test_split(
+                train_data, test_size=0.20, shuffle=True, random_state=42
+            )
+            self._validation = self.prepare_mapstype(self.val_data)
+
+            test_clean = os.path.join(self.root_dir, self.files.test_clean)
+            test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, test_clean, test_noisy, self.matching_function
+            )
+            test_data = fp.prepare_matching_dict()
+            self._test = self.prepare_mapstype(test_data)

     def prepare_mapstype(self, data):

         metadata = []
         for item in data:
             clean, noisy, total_dur = item.values()
             if total_dur < self.duration:
                 continue
             num_segments = round(total_dur / self.duration)
             for index in range(num_segments):
                 start_time = index * self.duration
-                self._validation.append(
-                    ({"clean": clean, "noisy": noisy}, start_time)
-                )
+                metadata.append(({"clean": clean, "noisy": noisy}, start_time))
         return metadata

     def train_dataloader(self):
         return DataLoader(
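After train_test_split carves the matched pairs into train and validation subsets, prepare_mapstype expands each pair into fixed-length segment entries: a pair whose total duration is 3.2 s with duration=1.0 yields round(3.2 / 1.0) = 3 entries, starting at 0.0, 1.0 and 2.0 s. Illustrative output for one such hypothetical pair:

    [
        ({"clean": "clean/p001.wav", "noisy": "noisy/p001.wav"}, 0.0),
        ({"clean": "clean/p001.wav", "noisy": "noisy/p001.wav"}, 1.0),
        ({"clean": "clean/p001.wav", "noisy": "noisy/p001.wav"}, 2.0),
    ]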
@@ -104,6 +132,13 @@ class TaskDataset(pl.LightningDataModule):
             num_workers=self.num_workers,
         )

+    def test_dataloader(self):
+        return DataLoader(
+            TestDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+

 class EnhancerDataset(TaskDataset):
     """
@@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
+        valid_size=0.2,
         duration=1.0,
         sampling_rate=48000,
         matching_function=None,
@@ -148,6 +184,7 @@ class EnhancerDataset(TaskDataset):
             name=name,
             root_dir=root_dir,
             files=files,
+            valid_size=valid_size,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,
@@ -183,6 +220,9 @@ class EnhancerDataset(TaskDataset):
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])

+    def test__getitem__(self, idx):
+        return self.prepare_segment(*self._test[idx])
+
     def prepare_segment(self, file_dict: dict, start_time: float):

         clean_segment = self.audio(
@@ -218,3 +258,6 @@ class EnhancerDataset(TaskDataset):

     def val__len__(self):
         return len(self._validation)
+
+    def test__len__(self):
+        return len(self._test)
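With the test hooks in place, the datamodule can drive a Lightning test loop end to end. A minimal usage sketch; the corpus name, paths, files spec and model are illustrative stand-ins, not part of this change:

    import pytorch_lightning as pl

    dm = EnhancerDataset(
        name="vctk",               # assumed corpus key understood by Fileprocessor.from_name
        root_dir="/path/to/data",  # placeholder path
        files=files,               # Files spec pointing at clean/noisy train and test folders
        duration=1.0,
        sampling_rate=48000,
    )
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, datamodule=dm)   # uses train_dataloader / val_dataloader
    trainer.test(model, datamodule=dm)  # exercises the new test_dataloader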
@@ -55,7 +55,7 @@ class ProcessorFunctions:
        One clean audio have multiple noisy audio files
        """

-        matching_wavfiles = dict()
+        matching_wavfiles = list()
        clean_filenames = [
            file.split("/")[-1]
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
@@ -73,7 +73,7 @@ class ProcessorFunctions:
                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                    sr_clean == sr_noisy
                ):
-                    matching_wavfiles.update(
+                    matching_wavfiles.append(
                        {
                            "clean": os.path.join(clean_path, clean_file),
                            "noisy": noisy_file,
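Switching matching_wavfiles from a dict to a list, and .update() to .append(), lets one clean file pair with several noisy recordings without later entries overwriting earlier ones. The hunk is truncated above; the resulting structure is roughly a list of per-pair records (the duration field name is assumed, the paths are placeholders):

    matching_wavfiles = [
        {"clean": "clean/p001.wav", "noisy": "noisy/p001_babble.wav", "duration": 3.2},
        {"clean": "clean/p001.wav", "noisy": "noisy/p001_street.wav", "duration": 3.2},
    ]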