fix imports

This commit is contained in:
shahules786 2022-09-23 18:21:54 +05:30
parent 24c7a6f1f0
commit 7641e5107c
1 changed files with 4 additions and 4 deletions

View File

@ -2,13 +2,14 @@ from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
import numpy as np
import os
from typing import Optional, Union, List, Path, Text, Dict, Any
from typing import Optional, Union, List, Text, Dict, Any
from torch.optim import Adam
import torch
from torch.nn.functional import pad
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse
from pathlib import Path
from enhancer import __version__
@ -29,13 +30,14 @@ class Model(pl.LightningModule):
sampling_rate:int=16000,
lr:float=1e-3,
dataset:Optional[Dataset]=None,
duration:Optional[float]=None,
loss: Union[str, List] = "mse",
metric:Union[str,List] = "mse"
):
super().__init__()
assert num_channels ==1 , "Enhancer only support for mono channel models"
self.dataset = dataset
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric")
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration")
@property
@ -44,8 +46,6 @@ class Model(pl.LightningModule):
@dataset.setter
def dataset(self,dataset):
if dataset is not None:
self.save_hyperparameters("duration",self.dataset.duration)
self._dataset = dataset
def setup(self,stage:Optional[str]=None):