fix imports

This commit is contained in:
shahules786 2022-09-23 18:21:54 +05:30
parent 24c7a6f1f0
commit 7641e5107c
1 changed files with 4 additions and 4 deletions

View File

@ -2,13 +2,14 @@ from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url from huggingface_hub import cached_download, hf_hub_url
import numpy as np import numpy as np
import os import os
from typing import Optional, Union, List, Path, Text, Dict, Any from typing import Optional, Union, List, Text, Dict, Any
from torch.optim import Adam from torch.optim import Adam
import torch import torch
from torch.nn.functional import pad from torch.nn.functional import pad
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse from urllib.parse import urlparse
from pathlib import Path
from enhancer import __version__ from enhancer import __version__
@ -29,13 +30,14 @@ class Model(pl.LightningModule):
sampling_rate:int=16000, sampling_rate:int=16000,
lr:float=1e-3, lr:float=1e-3,
dataset:Optional[Dataset]=None, dataset:Optional[Dataset]=None,
duration:Optional[float]=None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric:Union[str,List] = "mse" metric:Union[str,List] = "mse"
): ):
super().__init__() super().__init__()
assert num_channels ==1 , "Enhancer only support for mono channel models" assert num_channels ==1 , "Enhancer only support for mono channel models"
self.dataset = dataset self.dataset = dataset
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric") self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
@property @property
@ -44,8 +46,6 @@ class Model(pl.LightningModule):
@dataset.setter @dataset.setter
def dataset(self,dataset): def dataset(self,dataset):
if dataset is not None:
self.save_hyperparameters("duration",self.dataset.duration)
self._dataset = dataset self._dataset = dataset
def setup(self,stage:Optional[str]=None): def setup(self,stage:Optional[str]=None):