add support for custom loss

shahules786 2022-11-02 17:57:30 +05:30
parent 7f3dcf39c5
commit 2f85f48d69
1 changed file with 16 additions and 10 deletions

@@ -2,7 +2,7 @@ import os
 from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import List, Optional, Text, Union
+from typing import Any, List, Optional, Text, Union
 from urllib.parse import urlparse
 
 import numpy as np
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
 import torch
 from huggingface_hub import cached_download, hf_hub_url
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from torch import nn
 from torch.optim import Adam
 
 from enhancer.data.dataset import EnhancerDataset
@@ -49,7 +50,7 @@ class Model(pl.LightningModule):
         dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric: Union[str, List] = "mse",
+        metric: Union[str, List, Any] = "mse",
     ):
         super().__init__()
         assert (
@@ -86,20 +87,25 @@ class Model(pl.LightningModule):
     @metric.setter
     def metric(self, metric):
         self._metric = []
-        if isinstance(metric, str):
+        if isinstance(metric, (str, nn.Module)):
             metric = [metric]
         for func in metric:
-            if func in LOSS_MAP.keys():
-                if func in ("pesq", "stoi"):
-                    self._metric.append(
-                        LOSS_MAP[func](self.hparams.sampling_rate)
-                    )
+            if isinstance(func, str):
+                if func in LOSS_MAP.keys():
+                    if func in ("pesq", "stoi"):
+                        self._metric.append(
+                            LOSS_MAP[func](self.hparams.sampling_rate)
+                        )
+                    else:
+                        self._metric.append(LOSS_MAP[func]())
                 else:
-                    self._metric.append(LOSS_MAP[func]())
+                    ValueError(f"Invalid metrics {func}")
+            elif isinstance(func, nn.Module):
+                self._metric.append(func)
             else:
-                raise ValueError(f"Invalid metrics {func}")
+                raise ValueError("Invalid metrics")
 
     @property
     def dataset(self):
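
With this change, the metric argument accepts any nn.Module instance alongside the built-in string keys from LOSS_MAP, so a custom loss can be plugged in directly. A minimal sketch of what such a module could look like, and how it might be passed in; the LogCoshLoss class and the constructor call below are illustrative assumptions, not part of this diff:

import torch
from torch import nn

class LogCoshLoss(nn.Module):
    """Hypothetical custom loss: mean log-cosh error between waveforms."""

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # prediction and target are expected to be same-shaped waveform tensors
        return torch.log(torch.cosh(prediction - target)).mean()

# Assumed usage: the module can be mixed with the existing string metrics when
# constructing a Model subclass; only the metric plumbing above is from this commit.
# model = SomeEnhancerModel(dataset=my_dataset, loss="mse", metric=["stoi", LogCoshLoss()])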