add test dataloader
This commit is contained in:
parent
37fe86063d
commit
3e654d10a7
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||||
|
|
||||||
from enhancer.data.fileprocessor import Fileprocessor
|
from enhancer.data.fileprocessor import Fileprocessor
|
||||||
|
|
@ -36,12 +37,24 @@ class ValidDataset(Dataset):
|
||||||
return self.dataset.val__len__()
|
return self.dataset.val__len__()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataset(Dataset):
|
||||||
|
def __init__(self, dataset):
|
||||||
|
self.dataset = dataset
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.dataset.test__getitem__(idx)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.dataset.test__len__()
|
||||||
|
|
||||||
|
|
||||||
class TaskDataset(pl.LightningDataModule):
|
class TaskDataset(pl.LightningDataModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
root_dir: str,
|
root_dir: str,
|
||||||
files: Files,
|
files: Files,
|
||||||
|
valid_size: float = 0.20,
|
||||||
duration: float = 1.0,
|
duration: float = 1.0,
|
||||||
sampling_rate: int = 48000,
|
sampling_rate: int = 48000,
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
|
|
@ -60,8 +73,15 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
if num_workers is None:
|
if num_workers is None:
|
||||||
num_workers = multiprocessing.cpu_count() // 2
|
num_workers = multiprocessing.cpu_count() // 2
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
if valid_size > 0.0:
|
||||||
|
self.valid_size = valid_size
|
||||||
|
else:
|
||||||
|
raise ValueError("valid_size must be greater than 0")
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
prepare train/validation/test data splits
|
||||||
|
"""
|
||||||
|
|
||||||
if stage in ("fit", None):
|
if stage in ("fit", None):
|
||||||
|
|
||||||
|
|
@ -70,25 +90,33 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
fp = Fileprocessor.from_name(
|
fp = Fileprocessor.from_name(
|
||||||
self.name, train_clean, train_noisy, self.matching_function
|
self.name, train_clean, train_noisy, self.matching_function
|
||||||
)
|
)
|
||||||
self.train_data = fp.prepare_matching_dict()
|
train_data = fp.prepare_matching_dict()
|
||||||
|
self.train_data, self.val_data = train_test_split(
|
||||||
val_clean = os.path.join(self.root_dir, self.files.test_clean)
|
train_data, test_size=0.20, shuffle=True, random_state=42
|
||||||
val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
|
|
||||||
fp = Fileprocessor.from_name(
|
|
||||||
self.name, val_clean, val_noisy, self.matching_function
|
|
||||||
)
|
)
|
||||||
val_data = fp.prepare_matching_dict()
|
|
||||||
|
|
||||||
for item in val_data:
|
self._validation = self.prepare_mapstype(self.val_data)
|
||||||
|
|
||||||
|
test_clean = os.path.join(self.root_dir, self.files.test_clean)
|
||||||
|
test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
|
||||||
|
fp = Fileprocessor.from_name(
|
||||||
|
self.name, test_clean, test_noisy, self.matching_function
|
||||||
|
)
|
||||||
|
test_data = fp.prepare_matching_dict()
|
||||||
|
self._test = self.prepare_mapstype(test_data)
|
||||||
|
|
||||||
|
def prepare_mapstype(self, data):
|
||||||
|
|
||||||
|
metadata = []
|
||||||
|
for item in data:
|
||||||
clean, noisy, total_dur = item.values()
|
clean, noisy, total_dur = item.values()
|
||||||
if total_dur < self.duration:
|
if total_dur < self.duration:
|
||||||
continue
|
continue
|
||||||
num_segments = round(total_dur / self.duration)
|
num_segments = round(total_dur / self.duration)
|
||||||
for index in range(num_segments):
|
for index in range(num_segments):
|
||||||
start_time = index * self.duration
|
start_time = index * self.duration
|
||||||
self._validation.append(
|
metadata.append(({"clean": clean, "noisy": noisy}, start_time))
|
||||||
({"clean": clean, "noisy": noisy}, start_time)
|
return metadata
|
||||||
)
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
|
|
@ -104,6 +132,13 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
num_workers=self.num_workers,
|
num_workers=self.num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_dataloader(self):
|
||||||
|
return DataLoader(
|
||||||
|
TestDataset(self),
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
num_workers=self.num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EnhancerDataset(TaskDataset):
|
class EnhancerDataset(TaskDataset):
|
||||||
"""
|
"""
|
||||||
|
|
@ -137,6 +172,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
name: str,
|
name: str,
|
||||||
root_dir: str,
|
root_dir: str,
|
||||||
files: Files,
|
files: Files,
|
||||||
|
valid_size=0.2,
|
||||||
duration=1.0,
|
duration=1.0,
|
||||||
sampling_rate=48000,
|
sampling_rate=48000,
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
|
|
@ -148,6 +184,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
name=name,
|
name=name,
|
||||||
root_dir=root_dir,
|
root_dir=root_dir,
|
||||||
files=files,
|
files=files,
|
||||||
|
valid_size=valid_size,
|
||||||
sampling_rate=sampling_rate,
|
sampling_rate=sampling_rate,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
matching_function=matching_function,
|
matching_function=matching_function,
|
||||||
|
|
@ -183,6 +220,9 @@ class EnhancerDataset(TaskDataset):
|
||||||
def val__getitem__(self, idx):
|
def val__getitem__(self, idx):
|
||||||
return self.prepare_segment(*self._validation[idx])
|
return self.prepare_segment(*self._validation[idx])
|
||||||
|
|
||||||
|
def test__getitem__(self, idx):
|
||||||
|
return self.prepare_segment(*self._test[idx])
|
||||||
|
|
||||||
def prepare_segment(self, file_dict: dict, start_time: float):
|
def prepare_segment(self, file_dict: dict, start_time: float):
|
||||||
|
|
||||||
clean_segment = self.audio(
|
clean_segment = self.audio(
|
||||||
|
|
@ -218,3 +258,6 @@ class EnhancerDataset(TaskDataset):
|
||||||
|
|
||||||
def val__len__(self):
|
def val__len__(self):
|
||||||
return len(self._validation)
|
return len(self._validation)
|
||||||
|
|
||||||
|
def test__len__(self):
|
||||||
|
return len(self._test)
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ class ProcessorFunctions:
|
||||||
One clean audio have multiple noisy audio files
|
One clean audio have multiple noisy audio files
|
||||||
"""
|
"""
|
||||||
|
|
||||||
matching_wavfiles = dict()
|
matching_wavfiles = list()
|
||||||
clean_filenames = [
|
clean_filenames = [
|
||||||
file.split("/")[-1]
|
file.split("/")[-1]
|
||||||
for file in glob.glob(os.path.join(clean_path, "*.wav"))
|
for file in glob.glob(os.path.join(clean_path, "*.wav"))
|
||||||
|
|
@ -73,7 +73,7 @@ class ProcessorFunctions:
|
||||||
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
|
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
|
||||||
sr_clean == sr_noisy
|
sr_clean == sr_noisy
|
||||||
):
|
):
|
||||||
matching_wavfiles.update(
|
matching_wavfiles.append(
|
||||||
{
|
{
|
||||||
"clean": os.path.join(clean_path, clean_file),
|
"clean": os.path.join(clean_path, clean_file),
|
||||||
"noisy": noisy_file,
|
"noisy": noisy_file,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue