44 lines
		
	
	
		
			795 B
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			44 lines
		
	
	
		
			795 B
		
	
	
	
		
			Python
		
	
	
	
| from typing import Optional
 | |
| import pytorch_lightning as pl
 | |
| 
 | |
| from enhancer.data.dataset import Dataset
 | |
| 
 | |
| 
 | |
| class Model(pl.LightningModule):
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         dataset:Dataset
 | |
|     ):
 | |
|         super().__init__()
 | |
|         self.dataset = dataset
 | |
| 
 | |
|         pass
 | |
|     
 | |
|     @property
 | |
|     def dataset(self):
 | |
|         return self._dataset
 | |
| 
 | |
|     @dataset.setter
 | |
|     def dataset(self,dataset):
 | |
|         self._dataset = dataset
 | |
| 
 | |
|     def setup(
 | |
|         self,
 | |
|         stage:Optional[str]=None
 | |
|     ):
 | |
|         if stage == "fit":
 | |
|             self.dataset.setup(stage)
 | |
|             self.dataset.model = self 
 | |
|         
 | |
| 
 | |
|     def train_dataloader(
 | |
|         self
 | |
|     ):
 | |
|         return self.dataset.train_dataloader()
 | |
| 
 | |
|     def val_dataloader(
 | |
|         self
 | |
|     ):
 | |
|         return self.dataset.val_dataloader()
 |