diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index 106f649..5cc9b31 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -1,6 +1,5 @@ import glob import os -from re import S import numpy as np from scipy.io import wavfile diff --git a/enhancer/loss.py b/enhancer/loss.py index f2f62d3..db1d222 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -17,8 +17,8 @@ class mean_squared_error(nn.Module): if prediction.size() != target.size() or target.ndim < 3: raise TypeError( - f"""Inputs must be of the same shape (batch_size,channels,samples) - got {prediction.size()} and {target.size()} instead""" + f"""Inputs must be of the same shape (batch_size,channels,samples) + got {prediction.size()} and {target.size()} instead""" ) return self.loss_fun(prediction, target) @@ -39,7 +39,7 @@ class mean_absolute_error(nn.Module): if prediction.size() != target.size() or target.ndim < 3: raise TypeError( - f"""Inputs must be of the same shape (batch_size,channels,samples) + f"""Inputs must be of the same shape (batch_size,channels,samples) got {prediction.size()} and {target.size()} instead""" ) @@ -65,7 +65,7 @@ class Si_SDR(nn.Module): if prediction.size() != target.size() or target.ndim < 3: raise TypeError( - f"""Inputs must be of the same shape (batch_size,channels,samples) + f"""Inputs must be of the same shape (batch_size,channels,samples) got {prediction.size()} and {target.size()} instead""" ) @@ -119,7 +119,8 @@ class Avergeloss(nn.Module): def validate_loss(self, loss: str): if loss not in LOSS_MAP.keys(): raise ValueError( - f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}" + f"""Invalid loss function {loss}, available loss functions are + {tuple([loss for loss in LOSS_MAP.keys()])}""" ) else: return LOSS_MAP[loss] diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 071bbb6..56f24db 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,16 +1,10 @@ -try: - from functools import cached_property -except ImportError: - from backports.cached_property import cached_property from importlib import import_module from huggingface_hub import cached_download, hf_hub_url -import logging import numpy as np import os from typing import Optional, Union, List, Text, Dict, Any from torch.optim import Adam import torch -from torch.nn.functional import pad import pytorch_lightning as pl from pytorch_lightning.utilities.cloud_io import load as pl_load from urllib.parse import urlparse @@ -19,7 +13,6 @@ from pathlib import Path from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset -from enhancer.utils.io import Audio from enhancer.loss import Avergeloss from enhancer.inference import Inference @@ -300,7 +293,7 @@ class Model(pl.LightningModule): with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id : batch_id + batch_size, :, :].to( + batch_data = batch[batch_id: batch_id + batch_size, :, :].to( self.device ) prediction = self(batch_data) diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index 73673ed..ebb41b4 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -19,6 +19,7 @@ def check_files(root_dir: str, files: Files): def merge_dict(default_dict: dict, custom: Optional[dict] = None): + params = dict(default_dict) if custom: params.update(custom)