model assigment'
This commit is contained in:
parent
40e2d6e0b0
commit
ea5c78798a
|
|
@ -113,6 +113,8 @@ class Model(pl.LightningModule):
|
||||||
if stage == "fit":
|
if stage == "fit":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
self.dataset.setup(stage)
|
self.dataset.setup(stage)
|
||||||
|
self.dataset.model = self
|
||||||
|
|
||||||
print(
|
print(
|
||||||
"Total train duration",
|
"Total train duration",
|
||||||
self.dataset.train_dataloader().dataset.__len__()
|
self.dataset.train_dataloader().dataset.__len__()
|
||||||
|
|
@ -134,7 +136,6 @@ class Model(pl.LightningModule):
|
||||||
/ 60,
|
/ 60,
|
||||||
"minutes",
|
"minutes",
|
||||||
)
|
)
|
||||||
self.dataset.model = self
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return self.dataset.train_dataloader()
|
return self.dataset.train_dataloader()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue