From 2f85f48d69ce6d12edf9d6947123ffd0acbb8257 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 2 Nov 2022 17:57:30 +0530 Subject: [PATCH] add support for custom loss --- enhancer/models/model.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 3b60b85..92d30ae 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -2,7 +2,7 @@ import os from collections import defaultdict from importlib import import_module from pathlib import Path -from typing import List, Optional, Text, Union +from typing import Any, List, Optional, Text, Union from urllib.parse import urlparse import numpy as np @@ -10,6 +10,7 @@ import pytorch_lightning as pl import torch from huggingface_hub import cached_download, hf_hub_url from pytorch_lightning.utilities.cloud_io import load as pl_load +from torch import nn from torch.optim import Adam from enhancer.data.dataset import EnhancerDataset @@ -49,7 +50,7 @@ class Model(pl.LightningModule): dataset: Optional[EnhancerDataset] = None, duration: Optional[float] = None, loss: Union[str, List] = "mse", - metric: Union[str, List] = "mse", + metric: Union[str, List, Any] = "mse", ): super().__init__() assert ( @@ -86,20 +87,25 @@ class Model(pl.LightningModule): @metric.setter def metric(self, metric): self._metric = [] - if isinstance(metric, str): + if isinstance(metric, (str, nn.Module)): metric = [metric] for func in metric: - if func in LOSS_MAP.keys(): - if func in ("pesq", "stoi"): - self._metric.append( - LOSS_MAP[func](self.hparams.sampling_rate) - ) + if isinstance(func, str): + if func in LOSS_MAP.keys(): + if func in ("pesq", "stoi"): + self._metric.append( + LOSS_MAP[func](self.hparams.sampling_rate) + ) + else: + self._metric.append(LOSS_MAP[func]()) else: - self._metric.append(LOSS_MAP[func]()) + ValueError(f"Invalid metrics {func}") + elif isinstance(func, nn.Module): + self._metric.append(func) else: - raise ValueError(f"Invalid metrics {func}") + raise ValueError("Invalid metrics") @property def dataset(self):