configure base class
This commit is contained in:
parent
8230061b7b
commit
4cb417cdbe
|
|
@ -1,4 +1,5 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from torch.optim import Adam
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
|
||||||
from enhancer.data.dataset import Dataset
|
from enhancer.data.dataset import Dataset
|
||||||
|
|
@ -8,12 +9,16 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
def __init__(
    self,
    num_channels: int = 1,
    sampling_rate: int = 16000,
    lr: float = 1e-3,
    dataset: Optional[Dataset] = None,
):
    """Base model: record training hyperparameters and attach the dataset.

    Parameters
    ----------
    num_channels : int
        Number of audio channels; only mono (1) is supported.
    sampling_rate : int
        Audio sampling rate in Hz.
    lr : float
        Learning rate, later read back via ``self.hparams.lr``.
    dataset : Optional[Dataset]
        Optional data module; stored through the ``dataset`` property setter.

    Raises
    ------
    ValueError
        If ``num_channels`` is not 1.
    """
    super().__init__()
    # Explicit exception instead of `assert`: asserts are silently
    # stripped when Python runs with -O, so validation would vanish.
    if num_channels != 1:
        raise ValueError("Enhancer only supports mono channel models")
    self.save_hyperparameters("num_channels", "sampling_rate", "lr")
    # Goes through the `dataset` property setter (stores in self._dataset).
    self.dataset = dataset
||||||
@property
|
@property
|
||||||
def dataset(self):
|
def dataset(self):
|
||||||
|
|
@ -23,21 +28,24 @@ class Model(pl.LightningModule):
|
||||||
def dataset(self, dataset):
    """Setter for the ``dataset`` property: stash the value privately.

    The public ``dataset`` attribute is backed by ``self._dataset``.
    """
    self._dataset = dataset
||||||
def setup(
|
def setup(self,stage:Optional[str]=None):
|
||||||
self,
|
|
||||||
stage:Optional[str]=None
|
|
||||||
):
|
|
||||||
if stage == "fit":
|
if stage == "fit":
|
||||||
self.dataset.setup(stage)
|
self.dataset.setup(stage)
|
||||||
self.dataset.model = self
|
self.dataset.model = self
|
||||||
|
|
||||||
|
|
||||||
def train_dataloader(self):
    """Delegate training-batch loading to the attached data module."""
    loader_factory = self.dataset.train_dataloader
    return loader_factory()
||||||
def val_dataloader(self):
    """Delegate validation-batch loading to the attached data module."""
    loader_factory = self.dataset.val_dataloader
    return loader_factory()
||||||
|
def configure_optimizers(self):
    """Build the optimizer for Lightning.

    Returns
    -------
    Adam
        Adam over the model's parameters, with the learning rate taken
        from the saved hyperparameters (``self.hparams.lr``).
    """
    # Bug fix: the original passed `self.parameters` (the bound method
    # object) instead of calling `self.parameters()`; Adam requires an
    # iterable of tensors and would fail at runtime otherwise.
    return Adam(self.parameters(), lr=self.hparams.lr)
||||||
|
def training_step(self, batch, batch_idx: int):
    """Placeholder training step — subclasses are expected to override.

    Currently a no-op that returns ``None``.
    """
    return None
||||||
|
@classmethod
def from_pretrained(cls):
    """Placeholder alternate constructor for loading a pretrained model.

    Not implemented yet; returns ``None``.
    """
    return None
Loading…
Reference in New Issue