diff --git a/enhancer/models/model.py b/enhancer/models/model.py index c4be077..980c583 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,5 +1,6 @@ from importlib import import_module from huggingface_hub import cached_download, hf_hub_url +import logging import numpy as np import os from typing import Optional, Union, List, Text, Dict, Any @@ -37,6 +38,9 @@ class Model(pl.LightningModule): super().__init__() assert num_channels ==1 , "Enhancer only support for mono channel models" self.dataset = dataset + if self.dataset is not None: + sampling_rate = self.dataset.sampling_rate + logging.warn("Setting model sampling rate same as dataset sampling rate") self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") if self.logger: self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")