From 6a3c67fc13c8d11ef00e75eb79e47e79d7eb9bf9 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Fri, 14 Oct 2022 11:32:18 +0530 Subject: [PATCH] print train/val duration --- enhancer/models/model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 4f055b4..714b0e5 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -113,6 +113,20 @@ class Model(pl.LightningModule): if stage == "fit": torch.cuda.empty_cache() self.dataset.setup(stage) + print( + "Total train duration", + self.dataset.train_dataloader().__len__() + * self.dataset.duration + / 60, + "minutes", + ) + print( + "Total validation duration", + self.dataset.val_dataloader().__len__() + * self.dataset.duration + / 60, + "minutes", + ) self.dataset.model = self def train_dataloader(self):