44 lines
795 B
Python
44 lines
795 B
Python
from typing import Optional
|
|
import pytorch_lightning as pl
|
|
|
|
from enhancer.data.dataset import Dataset
|
|
|
|
|
|
class Model(pl.LightningModule):
    """LightningModule that delegates its data hooks to a wrapped ``Dataset``.

    The model keeps a reference to ``dataset`` and forwards Lightning's
    ``setup``, ``train_dataloader`` and ``val_dataloader`` hooks to the
    corresponding methods on that object.

    Args:
        dataset: Data source whose ``setup``, ``train_dataloader`` and
            ``val_dataloader`` methods are called by this model.
    """

    def __init__(self, dataset: Dataset):
        super().__init__()
        # Goes through the ``dataset`` property setter below.
        self.dataset = dataset

    @property
    def dataset(self) -> "Dataset":
        """The wrapped dataset (backed by ``self._dataset``)."""
        return self._dataset

    @dataset.setter
    def dataset(self, dataset: "Dataset") -> None:
        # Stored as-is; no type or state validation is performed here.
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None) -> None:
        """Lightning ``setup`` hook.

        When ``stage == "fit"``, calls ``self.dataset.setup(stage)`` and then
        gives the dataset a back-reference to this model through
        ``self.dataset.model``. Any other ``stage`` value, including ``None``,
        does nothing.

        Args:
            stage: Stage name supplied by Lightning, or ``None``.
        """
        # NOTE(review): stages other than "fit" (e.g. "validate" from
        # ``trainer.validate``) leave the dataset unprepared, so
        # ``val_dataloader`` may depend on an earlier "fit" setup -- confirm
        # this is intended.
        if stage == "fit":
            self.dataset.setup(stage)
            self.dataset.model = self

    def train_dataloader(self):
        """Return whatever ``self.dataset.train_dataloader()`` returns."""
        return self.dataset.train_dataloader()

    def val_dataloader(self):
        """Return whatever ``self.dataset.val_dataloader()`` returns."""
        return self.dataset.val_dataloader()