set training step

This commit is contained in:
shahules786 2022-09-10 11:42:29 +05:30
parent f8c8884ce9
commit 619d0be8ce
1 changed files with 22 additions and 4 deletions

View File

@ -1,8 +1,9 @@
from typing import Optional
from typing import Optional, Union, List
from torch.optim import Adam
import pytorch_lightning as pl
from enhancer.data.dataset import Dataset
from enhancer.utils.loss import LOSS_MAP, Avergeloss
class Model(pl.LightningModule):
@ -13,10 +14,11 @@ class Model(pl.LightningModule):
sampling_rate:int=16000,
lr:float=1e-3,
dataset:Optional[Dataset]=None,
loss: Union[str, List] = "mse"
):
super().__init__()
assert num_channels ==1 , "Enhancer only support for mono channel models"
self.save_hyperparameters("num_channels","sampling_rate","lr")
self.save_hyperparameters("num_channels","sampling_rate","lr","loss")
self.dataset = dataset
@ -32,7 +34,15 @@ class Model(pl.LightningModule):
if stage == "fit":
self.dataset.setup(stage)
self.dataset.model = self
self.setup_loss()
def setup_loss(self):
loss = self.hparams.loss
if isinstance(loss,str):
losses = [loss]
self.loss = Avergeloss(losses)
def train_dataloader(self):
return self.dataset.train_dataloader()
@ -44,7 +54,15 @@ class Model(pl.LightningModule):
return Adam(self.parameters, lr = self.hparams.lr)
def training_step(self,batch, batch_idx:int):
pass
mixed_waveform = batch["noisy"]
target = batch["clean"]
prediction = self(mixed_waveform)
loss = self.loss(prediction, target)
return {"loss":loss}
@classmethod
def from_pretrained(cls,):