add root dir
This commit is contained in:
parent
a939b4c37d
commit
7353fdafc2
|
|
@ -65,14 +65,14 @@ class EnhancerDataset(IterableDataset):
|
||||||
|
|
||||||
class Dataset(pl.LightningDataModule):
|
class Dataset(pl.LightningDataModule):
|
||||||
|
|
||||||
def __init__(self,name:str, files:Files,
|
def __init__(self,name:str,root_dir:str, files:Files,
|
||||||
duration:float=1.0, sampling_rate:int=48000, batch_size=32):
|
duration:float=1.0, sampling_rate:int=48000, batch_size=32):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.train_clean = files.train_clean
|
self.train_clean = os.path.join(root_dir, files.train_clean)
|
||||||
self.train_noisy = files.train_noisy
|
self.train_noisy = os.path.join(root_dir,files.train_noisy)
|
||||||
self.valid_clean = files.test_clean
|
self.valid_clean = os.path.join(root_dir,files.test_clean)
|
||||||
self.valid_noisy = files.test_noisy
|
self.valid_noisy = os.path.join(root_dir,files.test_noisy)
|
||||||
self.name = name
|
self.name = name
|
||||||
self.duration = duration
|
self.duration = duration
|
||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
|
|
@ -87,9 +87,9 @@ class Dataset(pl.LightningDataModule):
|
||||||
self.valid_dataset = EnhancerDataset(self.name, self.valid_clean,
|
self.valid_dataset = EnhancerDataset(self.name, self.valid_clean,
|
||||||
self.valid_noisy, self.duration, self.sampling_rate)
|
self.valid_noisy, self.duration, self.sampling_rate)
|
||||||
|
|
||||||
def train_loader(self):
|
def train_dataloader(self):
|
||||||
return DataLoader(self.train_dataset, batch_size = self.batch_size)
|
return DataLoader(self.train_dataset, batch_size = self.batch_size)
|
||||||
|
|
||||||
|
|
||||||
def valid_loader(self):
|
def valid_dataloader(self):
|
||||||
return DataLoader(self.valid_dataset, batch_size = self.batch_size)
|
return DataLoader(self.valid_dataset, batch_size = self.batch_size)
|
||||||
Loading…
Reference in New Issue