diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 87fab4f..640d090 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -2,13 +2,14 @@ from importlib import import_module from huggingface_hub import cached_download, hf_hub_url import numpy as np import os -from typing import Optional, Union, List, Path, Text, Dict, Any +from typing import Optional, Union, List, Text, Dict, Any from torch.optim import Adam import torch from torch.nn.functional import pad import pytorch_lightning as pl from pytorch_lightning.utilities.cloud_io import load as pl_load from urllib.parse import urlparse +from pathlib import Path from enhancer import __version__ @@ -29,13 +30,14 @@ class Model(pl.LightningModule): sampling_rate:int=16000, lr:float=1e-3, dataset:Optional[Dataset]=None, + duration:Optional[float]=None, loss: Union[str, List] = "mse", metric:Union[str,List] = "mse" ): super().__init__() assert num_channels ==1 , "Enhancer only support for mono channel models" self.dataset = dataset - self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric") + self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") @property @@ -44,8 +46,6 @@ class Model(pl.LightningModule): @dataset.setter def dataset(self,dataset): - if dataset is not None: - self.save_hyperparameters("duration",self.dataset.duration) self._dataset = dataset def setup(self,stage:Optional[str]=None):