From 2a6f310ba42194f2012d57577f24e102a99c45b8 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 24 Aug 2022 19:47:10 +0530 Subject: [PATCH] dataset --- enhancer/data/dataset.py | 34 ++++++++++++++++++++++++++++++++++ enhancer/data/vctk.py | 16 ---------------- 2 files changed, 34 insertions(+), 16 deletions(-) create mode 100644 enhancer/data/dataset.py diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py new file mode 100644 index 0000000..6b44001 --- /dev/null +++ b/enhancer/data/dataset.py @@ -0,0 +1,34 @@ +import os +import pytorch_lightning as pl +from typing import Optional + +from enhancer.data.vctk import Vctk +from enhancer.utils.config import Files + +DATASETS = ["Vctk"] + +class Dataset(pl.LightningDataModule): + + def __init__(self,name:str, directory:str, files:Files, + duration:float=1.0, sampling_rate:int=48000): + super().__init__() + + self.train_clean = os.path.join(directory,Files.train_clean) + self.train_noisy = os.path.join(directory,Files.train_noisy) + self.valid_clean = os.path.join(directory,Files.test_clean) + self.valid_noisy = os.path.join(directory,Files.test_noisy) + + if name.title() in DATASETS: + self.data_obj = eval(name.title) + + self.duration = duration + self.sampling_rate = sampling_rate + + def setup(self, stage: Optional[str] = None): + self.train_dataset = self.data_obj() + + def train_loader(self): + pass + + def valid_loader(self): + pass diff --git a/enhancer/data/vctk.py b/enhancer/data/vctk.py index 281ac2c..90cd87a 100644 --- a/enhancer/data/vctk.py +++ b/enhancer/data/vctk.py @@ -12,22 +12,6 @@ from enhancer.utils.io import Audio -class VctkDataset: - - def __init__(self): - pass - - def train_loader(self): - pass - - def valid_loader(self): - pass - - def test_loader(self): - pass - - - class Vctk(IterableDataset): """Dataset object for Voice Bank Corpus (VCTK) Dataset"""