model assigment'

This commit is contained in:
shahules786 2022-10-23 12:33:38 +05:30
parent 40e2d6e0b0
commit ea5c78798a
1 changed files with 2 additions and 1 deletions

View File

@ -113,6 +113,8 @@ class Model(pl.LightningModule):
if stage == "fit": if stage == "fit":
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.dataset.setup(stage) self.dataset.setup(stage)
self.dataset.model = self
print( print(
"Total train duration", "Total train duration",
self.dataset.train_dataloader().dataset.__len__() self.dataset.train_dataloader().dataset.__len__()
@ -134,7 +136,6 @@ class Model(pl.LightningModule):
/ 60, / 60,
"minutes", "minutes",
) )
self.dataset.model = self
def train_dataloader(self): def train_dataloader(self):
return self.dataset.train_dataloader() return self.dataset.train_dataloader()