Merge pull request #12 from shahules786/dev-reformat

Dev reformat
This commit is contained in:
Shahul ES 2022-10-05 15:49:16 +05:30 committed by GitHub
commit 324a060f01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 9 additions and 14 deletions

View File

@ -1,4 +1,5 @@
[flake8]
per-file-ignores = __init__.py:F401
ignore = E203, E266, E501, W503
# line length is intentionally set to 80 here because black uses Bugbear
# See https://github.com/psf/black/blob/master/README.md#line-length for more details

View File

@ -1,6 +1,5 @@
import glob
import os
from re import S
import numpy as np
from scipy.io import wavfile

View File

@ -17,8 +17,8 @@ class mean_squared_error(nn.Module):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
return self.loss_fun(prediction, target)
@ -39,7 +39,7 @@ class mean_absolute_error(nn.Module):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
@ -65,7 +65,7 @@ class Si_SDR(nn.Module):
if prediction.size() != target.size() or target.ndim < 3:
raise TypeError(
f"""Inputs must be of the same shape (batch_size,channels,samples)
f"""Inputs must be of the same shape (batch_size,channels,samples)
got {prediction.size()} and {target.size()} instead"""
)
@ -119,7 +119,8 @@ class Avergeloss(nn.Module):
def validate_loss(self, loss: str):
if loss not in LOSS_MAP.keys():
raise ValueError(
f"Invalid loss function {loss}, available loss functions are {tuple([loss for loss in LOSS_MAP.keys()])}"
f"""Invalid loss function {loss}, available loss functions are
{tuple([loss for loss in LOSS_MAP.keys()])}"""
)
else:
return LOSS_MAP[loss]

View File

@ -1,16 +1,10 @@
try:
from functools import cached_property
except ImportError:
from backports.cached_property import cached_property
from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
import logging
import numpy as np
import os
from typing import Optional, Union, List, Text, Dict, Any
from torch.optim import Adam
import torch
from torch.nn.functional import pad
import pytorch_lightning as pl
from pytorch_lightning.utilities.cloud_io import load as pl_load
from urllib.parse import urlparse
@ -19,7 +13,6 @@ from pathlib import Path
from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset
from enhancer.utils.io import Audio
from enhancer.loss import Avergeloss
from enhancer.inference import Inference
@ -300,7 +293,7 @@ class Model(pl.LightningModule):
with torch.no_grad():
for batch_id in range(0, batch.shape[0], batch_size):
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
batch_data = batch[batch_id: batch_id + batch_size, :, :].to(
self.device
)
prediction = self(batch_data)

View File

@ -19,6 +19,7 @@ def check_files(root_dir: str, files: Files):
def merge_dict(default_dict: dict, custom: Optional[dict] = None):
params = dict(default_dict)
if custom:
params.update(custom)