add root dir

This commit is contained in:
shahules786 2022-09-01 09:46:38 +05:30
parent a939b4c37d
commit 7353fdafc2
1 changed files with 7 additions and 7 deletions

View File

@ -65,14 +65,14 @@ class EnhancerDataset(IterableDataset):
class Dataset(pl.LightningDataModule): class Dataset(pl.LightningDataModule):
def __init__(self,name:str, files:Files, def __init__(self,name:str,root_dir:str, files:Files,
duration:float=1.0, sampling_rate:int=48000, batch_size=32): duration:float=1.0, sampling_rate:int=48000, batch_size=32):
super().__init__() super().__init__()
self.train_clean = files.train_clean self.train_clean = os.path.join(root_dir, files.train_clean)
self.train_noisy = files.train_noisy self.train_noisy = os.path.join(root_dir,files.train_noisy)
self.valid_clean = files.test_clean self.valid_clean = os.path.join(root_dir,files.test_clean)
self.valid_noisy = files.test_noisy self.valid_noisy = os.path.join(root_dir,files.test_noisy)
self.name = name self.name = name
self.duration = duration self.duration = duration
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
@ -87,9 +87,9 @@ class Dataset(pl.LightningDataModule):
self.valid_dataset = EnhancerDataset(self.name, self.valid_clean, self.valid_dataset = EnhancerDataset(self.name, self.valid_clean,
self.valid_noisy, self.duration, self.sampling_rate) self.valid_noisy, self.duration, self.sampling_rate)
def train_loader(self): def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size = self.batch_size) return DataLoader(self.train_dataset, batch_size = self.batch_size)
def valid_loader(self): def valid_dataloader(self):
return DataLoader(self.valid_dataset, batch_size = self.batch_size) return DataLoader(self.valid_dataset, batch_size = self.batch_size)