diff --git a/enhancer/models/model.py b/enhancer/models/model.py index cbbfad8..e7c5879 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url from pytorch_lightning.utilities.cloud_io import load as pl_load from torch.optim import Adam -from enhancer import __version__ from enhancer.data.dataset import EnhancerDataset from enhancer.inference import Inference from enhancer.loss import Avergeloss +from enhancer.version import __version__ CACHE_DIR = "" HF_TORCH_WEIGHTS = "" @@ -298,10 +298,9 @@ class Model(pl.LightningModule): ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" batch_predictions = [] self.eval().to(self.device) - with torch.no_grad(): for batch_id in range(0, batch.shape[0], batch_size): - batch_data = batch[batch_id : batch_id + batch_size, :, :].to( + batch_data = batch[batch_id : (batch_id + batch_size), :, :].to( self.device ) prediction = self(batch_data)