fix model sr to dataset sr
This commit is contained in:
		
							parent
							
								
									18759a3f84
								
							
						
					
					
						commit
						e22cecaf20
					
				|  | @ -1,5 +1,6 @@ | ||||||
| from importlib import import_module | from importlib import import_module | ||||||
| from huggingface_hub import cached_download, hf_hub_url | from huggingface_hub import cached_download, hf_hub_url | ||||||
|  | import logging | ||||||
| import numpy as np | import numpy as np | ||||||
| import os | import os | ||||||
| from typing import Optional, Union, List, Text, Dict, Any | from typing import Optional, Union, List, Text, Dict, Any | ||||||
|  | @ -37,6 +38,9 @@ class Model(pl.LightningModule): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         assert num_channels ==1 , "Enhancer only support for mono channel models" |         assert num_channels ==1 , "Enhancer only support for mono channel models" | ||||||
|         self.dataset = dataset |         self.dataset = dataset | ||||||
|  |         if self.dataset is not None: | ||||||
|  |             sampling_rate = self.dataset.sampling_rate | ||||||
|  |             logging.warn("Setting model sampling rate same as dataset sampling rate") | ||||||
|         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") |         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") | ||||||
|         if self.logger: |         if self.logger: | ||||||
|             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") |             self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786