diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 640d090..5d1cd3e 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -13,7 +13,7 @@ from pathlib import Path from enhancer import __version__ -from enhancer.data.dataset import Dataset +from enhancer.data.dataset import EnhancerDataset from enhancer.utils.io import Audio from enhancer.utils.loss import Avergeloss from enhancer.inference import Inference @@ -29,7 +29,7 @@ class Model(pl.LightningModule): num_channels:int=1, sampling_rate:int=16000, lr:float=1e-3, - dataset:Optional[Dataset]=None, + dataset:Optional[EnhancerDataset]=None, duration:Optional[float]=None, loss: Union[str, List] = "mse", metric:Union[str,List] = "mse"