print train/val duration

This commit is contained in:
shahules786 2022-10-14 11:32:18 +05:30
parent 204de08a9a
commit 6a3c67fc13
1 changed files with 14 additions and 0 deletions

View File

@ -113,6 +113,20 @@ class Model(pl.LightningModule):
if stage == "fit": if stage == "fit":
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.dataset.setup(stage) self.dataset.setup(stage)
print(
"Total train duration",
self.dataset.train_dataloader().__len__()
* self.dataset.duration
/ 60,
"minutes",
)
print(
"Total validation duration",
self.dataset.val_dataloader().__len__()
* self.dataset.duration
/ 60,
"minutes",
)
self.dataset.model = self self.dataset.model = self
def train_dataloader(self): def train_dataloader(self):