add root dir
This commit is contained in:
		
							parent
							
								
									a939b4c37d
								
							
						
					
					
						commit
						7353fdafc2
					
				|  | @ -65,14 +65,14 @@ class EnhancerDataset(IterableDataset): | |||
|          | ||||
| class Dataset(pl.LightningDataModule): | ||||
| 
 | ||||
|     def __init__(self,name:str, files:Files,  | ||||
|     def __init__(self,name:str,root_dir:str, files:Files,  | ||||
|                     duration:float=1.0, sampling_rate:int=48000, batch_size=32): | ||||
|         super().__init__() | ||||
| 
 | ||||
|         self.train_clean = files.train_clean | ||||
|         self.train_noisy = files.train_noisy  | ||||
|         self.valid_clean = files.test_clean | ||||
|         self.valid_noisy = files.test_noisy | ||||
|         self.train_clean = os.path.join(root_dir, files.train_clean) | ||||
|         self.train_noisy = os.path.join(root_dir,files.train_noisy)  | ||||
|         self.valid_clean = os.path.join(root_dir,files.test_clean) | ||||
|         self.valid_noisy = os.path.join(root_dir,files.test_noisy) | ||||
|         self.name = name | ||||
|         self.duration = duration | ||||
|         self.sampling_rate = sampling_rate | ||||
|  | @ -87,9 +87,9 @@ class Dataset(pl.LightningDataModule): | |||
|             self.valid_dataset = EnhancerDataset(self.name, self.valid_clean,  | ||||
|                                 self.valid_noisy, self.duration, self.sampling_rate) | ||||
| 
 | ||||
|     def train_loader(self): | ||||
|     def train_dataloader(self): | ||||
|         return DataLoader(self.train_dataset, batch_size = self.batch_size) | ||||
| 
 | ||||
| 
 | ||||
|     def valid_loader(self): | ||||
|     def valid_dataloader(self): | ||||
|         return DataLoader(self.valid_dataset, batch_size = self.batch_size) | ||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786