From 3084ffac198edeb64c19065f22c06192c088675d Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 6 Oct 2022 20:51:18 +0530 Subject: [PATCH] fix version --- enhancer/models/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index cbbfad8..e7c5879 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url from pytorch_lightning.utilities.cloud_io import load as pl_load from torch.optim import Adam -from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset from enhancer.inference import Inference from enhancer.loss import Avergeloss +from enhancer.version import __version__ CACHE_DIR = "" HF_TORCH_WEIGHTS = "" @@ -298,10 +298,9 @@ class Model(pl.LightningModule): ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" batch_predictions = [] self.eval().to(self.device) - with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id : batch_id + batch_size, :, :].to( + batch_data = batch[batch_id : (batch_id + batch_size), :, :].to( self.device ) prediction = self(batch_data)