From 619d0be8ce87291f3f4feb5829a5f6221db6d8b4 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Sat, 10 Sep 2022 11:42:29 +0530 Subject: [PATCH] set training step --- enhancer/models/model.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index a5b305f..f8dd502 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import Optional, Union, List from torch.optim import Adam import pytorch_lightning as pl from enhancer.data.dataset import Dataset +from enhancer.utils.loss import LOSS_MAP, Avergeloss class Model(pl.LightningModule): @@ -13,10 +14,11 @@ class Model(pl.LightningModule): sampling_rate:int=16000, lr:float=1e-3, dataset:Optional[Dataset]=None, + loss: Union[str, List] = "mse" ): super().__init__() assert num_channels ==1 , "Enhancer only support for mono channel models" - self.save_hyperparameters("num_channels","sampling_rate","lr") + self.save_hyperparameters("num_channels","sampling_rate","lr","loss") self.dataset = dataset @@ -31,8 +33,16 @@ class Model(pl.LightningModule): def setup(self,stage:Optional[str]=None): if stage == "fit": self.dataset.setup(stage) - self.dataset.model = self + self.dataset.model = self + self.setup_loss() + def setup_loss(self): + + loss = self.hparams.loss + if isinstance(loss,str): + losses = [loss] + + self.loss = Avergeloss(losses) def train_dataloader(self): return self.dataset.train_dataloader() @@ -44,7 +54,15 @@ class Model(pl.LightningModule): return Adam(self.parameters, lr = self.hparams.lr) def training_step(self,batch, batch_idx:int): - pass + + mixed_waveform = batch["noisy"] + target = batch["clean"] + prediction = self(mixed_waveform) + + loss = self.loss(prediction, target) + + return {"loss":loss} + @classmethod def from_pretrained(cls,):