add root dir
This commit is contained in:
parent
a939b4c37d
commit
7353fdafc2
|
|
@ -65,14 +65,14 @@ class EnhancerDataset(IterableDataset):
|
|||
|
||||
class Dataset(pl.LightningDataModule):
|
||||
|
||||
def __init__(self,name:str, files:Files,
|
||||
def __init__(self,name:str,root_dir:str, files:Files,
|
||||
duration:float=1.0, sampling_rate:int=48000, batch_size=32):
|
||||
super().__init__()
|
||||
|
||||
self.train_clean = files.train_clean
|
||||
self.train_noisy = files.train_noisy
|
||||
self.valid_clean = files.test_clean
|
||||
self.valid_noisy = files.test_noisy
|
||||
self.train_clean = os.path.join(root_dir, files.train_clean)
|
||||
self.train_noisy = os.path.join(root_dir,files.train_noisy)
|
||||
self.valid_clean = os.path.join(root_dir,files.test_clean)
|
||||
self.valid_noisy = os.path.join(root_dir,files.test_noisy)
|
||||
self.name = name
|
||||
self.duration = duration
|
||||
self.sampling_rate = sampling_rate
|
||||
|
|
@ -87,9 +87,9 @@ class Dataset(pl.LightningDataModule):
|
|||
self.valid_dataset = EnhancerDataset(self.name, self.valid_clean,
|
||||
self.valid_noisy, self.duration, self.sampling_rate)
|
||||
|
||||
def train_loader(self):
|
||||
def train_dataloader(self):
|
||||
return DataLoader(self.train_dataset, batch_size = self.batch_size)
|
||||
|
||||
|
||||
def valid_loader(self):
|
||||
def valid_dataloader(self):
|
||||
return DataLoader(self.valid_dataset, batch_size = self.batch_size)
|
||||
Loading…
Reference in New Issue