fix version

This commit is contained in:
shahules786 2022-10-06 20:51:18 +05:30
parent 975e8bc50e
commit 3084ffac19
1 changed files with 2 additions and 3 deletions

View File

@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch.optim import Adam
from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import Avergeloss
from enhancer.version import __version__
CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
@ -298,10 +298,9 @@ class Model(pl.LightningModule):
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
batch_predictions = []
self.eval().to(self.device)
with torch.no_grad():
for batch_id in range(0, batch.shape[0], batch_size):
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
self.device
)
prediction = self(batch_data)