models
This commit is contained in:
parent
d20b7a166f
commit
e5d9eb7e95
|
|
@ -1,3 +1,3 @@
|
|||
from enhancer.models.demucs import Demucs
|
||||
from enhancer.models.waveunet import WaveUnet
|
||||
from enhancer.models.model import Model
|
||||
from enhancer.models.waveunet import WaveUnet
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import logging
|
||||
from typing import Optional, Union, List
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from enhancer.models.model import Model
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.models.model import Model
|
||||
from enhancer.utils.io import Audio as audio
|
||||
from enhancer.utils.utils import merge_dict
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,20 @@
|
|||
from importlib import import_module
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
import numpy as np
|
||||
import os
|
||||
from typing import Optional, Union, List, Text, Dict, Any
|
||||
from torch.optim import Adam
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from urllib.parse import urlparse
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Text, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from torch.optim import Adam
|
||||
|
||||
from enhancer import __version__
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.loss import Avergeloss
|
||||
from enhancer.inference import Inference
|
||||
from enhancer.loss import Avergeloss
|
||||
|
||||
CACHE_DIR = ""
|
||||
HF_TORCH_WEIGHTS = ""
|
||||
|
|
@ -293,7 +293,7 @@ class Model(pl.LightningModule):
|
|||
|
||||
with torch.no_grad():
|
||||
for batch_id in range(0, batch.shape[0], batch_size):
|
||||
batch_data = batch[batch_id: batch_id + batch_size, :, :].to(
|
||||
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
|
||||
self.device
|
||||
)
|
||||
prediction = self(batch_data)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Union, List
|
||||
|
||||
from enhancer.models.model import Model
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.models.model import Model
|
||||
|
||||
|
||||
class WavenetDecoder(nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in New Issue