From 7353fdafc2a97d9a8f5295275636688e82be9c64 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 1 Sep 2022 09:46:38 +0530 Subject: [PATCH] add root dir --- enhancer/data/dataset.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 4bc5b3b..2807aed 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -65,14 +65,14 @@ class EnhancerDataset(IterableDataset): class Dataset(pl.LightningDataModule): - def __init__(self,name:str, files:Files, + def __init__(self,name:str,root_dir:str, files:Files, duration:float=1.0, sampling_rate:int=48000, batch_size=32): super().__init__() - self.train_clean = files.train_clean - self.train_noisy = files.train_noisy - self.valid_clean = files.test_clean - self.valid_noisy = files.test_noisy + self.train_clean = os.path.join(root_dir, files.train_clean) + self.train_noisy = os.path.join(root_dir,files.train_noisy) + self.valid_clean = os.path.join(root_dir,files.test_clean) + self.valid_noisy = os.path.join(root_dir,files.test_noisy) self.name = name self.duration = duration self.sampling_rate = sampling_rate @@ -87,9 +87,9 @@ class Dataset(pl.LightningDataModule): self.valid_dataset = EnhancerDataset(self.name, self.valid_clean, self.valid_noisy, self.duration, self.sampling_rate) - def train_loader(self): + def train_dataloader(self): return DataLoader(self.train_dataset, batch_size = self.batch_size) - def valid_loader(self): + def valid_dataloader(self): return DataLoader(self.valid_dataset, batch_size = self.batch_size) \ No newline at end of file