fix model sr to dataset sr
This commit is contained in:
parent
18759a3f84
commit
e22cecaf20
|
|
@ -1,5 +1,6 @@
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from huggingface_hub import cached_download, hf_hub_url
|
from huggingface_hub import cached_download, hf_hub_url
|
||||||
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Union, List, Text, Dict, Any
|
from typing import Optional, Union, List, Text, Dict, Any
|
||||||
|
|
@ -37,6 +38,9 @@ class Model(pl.LightningModule):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
if self.dataset is not None:
|
||||||
|
sampling_rate = self.dataset.sampling_rate
|
||||||
|
logging.warn("Setting model sampling rate same as dataset sampling rate")
|
||||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
|
||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue