fix version
This commit is contained in:
parent
975e8bc50e
commit
3084ffac19
|
|
@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url
|
|||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from torch.optim import Adam
|
||||
|
||||
from enhancer import __version__
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.inference import Inference
|
||||
from enhancer.loss import Avergeloss
|
||||
from enhancer.version import __version__
|
||||
|
||||
CACHE_DIR = ""
|
||||
HF_TORCH_WEIGHTS = ""
|
||||
|
|
@ -298,10 +298,9 @@ class Model(pl.LightningModule):
|
|||
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||
batch_predictions = []
|
||||
self.eval().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_id in range(0, batch.shape[0], batch_size):
|
||||
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
|
||||
batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
|
||||
self.device
|
||||
)
|
||||
prediction = self(batch_data)
|
||||
|
|
|
|||
Loading…
Reference in New Issue