diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index c7bfe30..a5b305f 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from torch.optim import Adam
 import pytorch_lightning as pl
 
 from enhancer.data.dataset import Dataset
@@ -8,12 +9,16 @@ class Model(pl.LightningModule):
 
     def __init__(
         self,
-        dataset:Dataset
+        num_channels:int=1,
+        sampling_rate:int=16000,
+        lr:float=1e-3,
+        dataset:Optional[Dataset]=None,
     ):
         super().__init__()
+        assert num_channels == 1, "Enhancer only supports mono-channel models"
+        self.save_hyperparameters("num_channels","sampling_rate","lr")
         self.dataset = dataset
-        pass
 
     @property
     def dataset(self):
         return self._dataset
@@ -23,21 +28,24 @@ class Model(pl.LightningModule):
     @dataset.setter
     def dataset(self,dataset):
         self._dataset = dataset
 
-    def setup(
-        self,
-        stage:Optional[str]=None
-    ):
+    def setup(self,stage:Optional[str]=None):
         if stage == "fit":
             self.dataset.setup(stage)
             self.dataset.model = self
 
-    def train_dataloader(
-        self
-    ):
+    def train_dataloader(self):
         return self.dataset.train_dataloader()
 
-    def val_dataloader(
-        self
-    ):
+    def val_dataloader(self):
         return self.dataset.val_dataloader()
+
+    def configure_optimizers(self):
+        return Adam(self.parameters(), lr=self.hparams.lr)
+
+    def training_step(self,batch, batch_idx:int):
+        pass
+
+    @classmethod
+    def from_pretrained(cls):
+        pass
\ No newline at end of file
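Usage sketch (not part of the diff above): `training_step` is left as a stub, so a concrete model is expected to subclass `Model` and implement it. The sketch below is a minimal illustration under two assumptions that this repository does not confirm: the `DummyEnhancer` name and its 1x1-convolution "architecture" are invented placeholders, and the `Dataset` is assumed to yield `(noisy, clean)` waveform pairs.

    import torch
    import torch.nn.functional as F

    from enhancer.models.model import Model


    class DummyEnhancer(Model):
        """Illustrative subclass; the network here is a placeholder, not a real enhancer."""

        def __init__(self, num_channels=1, sampling_rate=16000, lr=1e-3, dataset=None):
            super().__init__(
                num_channels=num_channels,
                sampling_rate=sampling_rate,
                lr=lr,
                dataset=dataset,
            )
            # Placeholder 1x1 convolution standing in for a real enhancement network.
            self.conv = torch.nn.Conv1d(num_channels, num_channels, kernel_size=1)

        def forward(self, noisy):
            return self.conv(noisy)

        def training_step(self, batch, batch_idx):
            # Assumes each batch is a (noisy, clean) pair of (batch, channel, time)
            # tensors; adjust to whatever enhancer.data.dataset.Dataset actually yields.
            noisy, clean = batch
            loss = F.mse_loss(self(noisy), clean)
            self.log("train_loss", loss)
            return loss

Training then goes through the usual Lightning wiring; `my_dataset` below is a hypothetical `enhancer.data.dataset.Dataset` instance:

    import pytorch_lightning as pl

    model = DummyEnhancer(dataset=my_dataset)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model)  # calls setup("fit"), which attaches the model to the dataset

This sketch relies on the design choice visible in the diff: `setup`, `train_dataloader`, and `val_dataloader` all delegate to the `Dataset` object, and `setup("fit")` also assigns the model back to `dataset.model`, so data and model stay coupled through the single `dataset` attribute.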