configure base class

This commit is contained in:
shahules786 2022-09-08 11:35:10 +05:30
parent 8230061b7b
commit 4cb417cdbe
1 changed file with 20 additions and 12 deletions

View File

@@ -1,4 +1,5 @@
from typing import Optional
from torch.optim import Adam
import pytorch_lightning as pl
from enhancer.data.dataset import Dataset
@@ -8,12 +9,16 @@ class Model(pl.LightningModule):
def __init__(
    self,
    num_channels: int = 1,
    sampling_rate: int = 16000,
    lr: float = 1e-3,
    dataset: Optional[Dataset] = None,
):
    """Base enhancement model.

    Parameters
    ----------
    num_channels : int
        Number of audio channels; only mono (1) is supported.
    sampling_rate : int
        Audio sampling rate in Hz.
    lr : float
        Learning rate used by `configure_optimizers`.
    dataset : Optional[Dataset]
        Training/validation dataset; optional so a model can be
        instantiated for inference without data attached.
    """
    super().__init__()
    # Mono-only restriction enforced up front; multi-channel is unsupported.
    assert num_channels == 1, "Enhancer only support for mono channel models"
    # `dataset` is deliberately excluded from hparams (not serializable).
    self.save_hyperparameters("num_channels", "sampling_rate", "lr")
    self.dataset = dataset
@property
def dataset(self):
@@ -23,21 +28,24 @@ class Model(pl.LightningModule):
# NOTE(review): diff context suggests this is the body of the `dataset`
# property setter (`@dataset.setter`) paired with the getter above — the
# decorator line is not visible in this hunk; confirm against the full file.
def dataset(self,dataset):
    # Store on the private backing attribute read by the property getter.
    self._dataset = dataset
def setup(self, stage: Optional[str] = None):
    """Lightning setup hook.

    On the "fit" stage, forwards setup to the attached dataset and gives
    the dataset a back-reference to this model so it can query model
    attributes (e.g. sampling rate) when building loaders.

    Parameters
    ----------
    stage : Optional[str]
        Lightning stage name ("fit", "validate", "test", ...); only
        "fit" triggers any work here.
    """
    if stage == "fit":
        self.dataset.setup(stage)
        self.dataset.model = self
def train_dataloader(self):
    """Return the training dataloader, delegating to the attached dataset.

    Raises
    ------
    AttributeError
        If no dataset was attached at construction (``dataset=None``).
    """
    return self.dataset.train_dataloader()
def val_dataloader(self):
    """Return the validation dataloader, delegating to the attached dataset.

    Raises
    ------
    AttributeError
        If no dataset was attached at construction (``dataset=None``).
    """
    return self.dataset.val_dataloader()
def configure_optimizers(self):
    """Create the optimizer for Lightning.

    Returns an Adam optimizer over the model's parameters using the
    learning rate stored in hyperparameters.
    """
    # BUG FIX: the original passed the bound method `self.parameters`
    # instead of calling it — Adam requires an iterable of Tensors and
    # would raise TypeError at fit time.
    return Adam(self.parameters(), lr=self.hparams.lr)
def training_step(self,batch, batch_idx:int):
    """Placeholder training step.

    Not implemented in this commit — subclasses / a later commit are
    expected to compute and return the training loss here.
    """
    pass
@classmethod
def from_pretrained(cls,):
    """Placeholder alternate constructor for loading a pretrained model.

    Not implemented in this commit; signature (including the bare
    ``cls,``) kept byte-identical to the source.
    """
    pass