From a1103310f2c814dee83cfd127bcfc053318bf83a Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 12 Sep 2022 11:33:34 +0530 Subject: [PATCH] validation step --- enhancer/models/model.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index bdc49dd..2adfef8 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -14,11 +14,12 @@ class Model(pl.LightningModule): sampling_rate:int=16000, lr:float=1e-3, dataset:Optional[Dataset]=None, - loss: Union[str, List] = "mse" + loss: Union[str, List] = "mse", + metric:Union[str,List] = "mse" ): super().__init__() assert num_channels ==1 , "Enhancer only support for mono channel models" - self.save_hyperparameters("num_channels","sampling_rate","lr","loss") + self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric") self.dataset = dataset @@ -34,15 +35,15 @@ class Model(pl.LightningModule): if stage == "fit": self.dataset.setup(stage) self.dataset.model = self - self.setup_loss() + self.loss = self.setup_loss(self.hparams.loss) + self.metric = self.setup_loss(self.hparams.metric) - def setup_loss(self): + def setup_loss(self,loss): - loss = self.hparams.loss if isinstance(loss,str): losses = [loss] - self.loss = Avergeloss(losses) + return Avergeloss(losses) def train_dataloader(self): return self.dataset.train_dataloader() @@ -63,6 +64,15 @@ class Model(pl.LightningModule): return {"loss":loss} + def validation_step(self,batch,batch_idx:int): + + mixed_waveform = batch["noisy"] + target = batch["clean"] + prediction = self(mixed_waveform) + + loss = self.metric(prediction, target) + + return {"loss":loss} @classmethod def from_pretrained(cls,):