add root dir

This commit is contained in:
shahules786 2022-09-01 09:46:38 +05:30
parent a939b4c37d
commit 7353fdafc2
1 changed files with 7 additions and 7 deletions

View File

@ -65,14 +65,14 @@ class EnhancerDataset(IterableDataset):
class Dataset(pl.LightningDataModule):
def __init__(self,name:str, files:Files,
def __init__(self,name:str,root_dir:str, files:Files,
duration:float=1.0, sampling_rate:int=48000, batch_size=32):
super().__init__()
self.train_clean = files.train_clean
self.train_noisy = files.train_noisy
self.valid_clean = files.test_clean
self.valid_noisy = files.test_noisy
self.train_clean = os.path.join(root_dir, files.train_clean)
self.train_noisy = os.path.join(root_dir,files.train_noisy)
self.valid_clean = os.path.join(root_dir,files.test_clean)
self.valid_noisy = os.path.join(root_dir,files.test_noisy)
self.name = name
self.duration = duration
self.sampling_rate = sampling_rate
@ -87,9 +87,9 @@ class Dataset(pl.LightningDataModule):
self.valid_dataset = EnhancerDataset(self.name, self.valid_clean,
self.valid_noisy, self.duration, self.sampling_rate)
def train_loader(self):
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size = self.batch_size)
def valid_loader(self):
def valid_dataloader(self):
return DataLoader(self.valid_dataset, batch_size = self.batch_size)