From e5d9eb7e95737565066ddc3ee3ff7c73c6e36d88 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 20:35:33 +0530 Subject: [PATCH] models --- enhancer/models/__init__.py | 2 +- enhancer/models/demucs.py | 9 +++++---- enhancer/models/model.py | 22 +++++++++++----------- enhancer/models/waveunet.py | 5 +++-- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py index 368a9d7..2d97568 100644 --- a/enhancer/models/__init__.py +++ b/enhancer/models/__init__.py @@ -1,3 +1,3 @@ from enhancer.models.demucs import Demucs -from enhancer.models.waveunet import WaveUnet from enhancer.models.model import Model +from enhancer.models.waveunet import WaveUnet diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 76a0bf7..65f119d 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -1,11 +1,12 @@ import logging -from typing import Optional, Union, List -from torch import nn -import torch.nn.functional as F import math +from typing import List, Optional, Union + +import torch.nn.functional as F +from torch import nn -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model from enhancer.utils.io import Audio as audio from enhancer.utils.utils import merge_dict diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 56f24db..39dbe80 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,20 +1,20 @@ -from importlib import import_module -from huggingface_hub import cached_download, hf_hub_url -import numpy as np import os -from typing import Optional, Union, List, Text, Dict, Any -from torch.optim import Adam -import torch -import pytorch_lightning as pl -from pytorch_lightning.utilities.cloud_io import load as pl_load -from urllib.parse import urlparse +from importlib import import_module from pathlib import Path +from typing import Any, Dict, List, Optional, Text, Union +from urllib.parse import urlparse +import numpy as np +import pytorch_lightning as pl +import torch +from huggingface_hub import cached_download, hf_hub_url +from pytorch_lightning.utilities.cloud_io import load as pl_load +from torch.optim import Adam from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset -from enhancer.loss import Avergeloss from enhancer.inference import Inference +from enhancer.loss import Avergeloss CACHE_DIR = "" HF_TORCH_WEIGHTS = "" @@ -293,7 +293,7 @@ class Model(pl.LightningModule): with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id: batch_id + batch_size, :, :].to( + batch_data = batch[batch_id : batch_id + batch_size, :, :].to( self.device ) prediction = self(batch_data) diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index 4d5cc0a..ebb4b1f 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -1,11 +1,12 @@ import logging +from typing import List, Optional, Union + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Union, List -from enhancer.models.model import Model from enhancer.data.dataset import EnhancerDataset +from enhancer.models.model import Model class WavenetDecoder(nn.Module):