add support for custom loss

This commit is contained in:
shahules786 2022-11-02 17:57:30 +05:30
parent 7f3dcf39c5
commit 2f85f48d69
1 changed files with 16 additions and 10 deletions

View File

@ -2,7 +2,7 @@ import os
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import List, Optional, Text, Union
from typing import Any, List, Optional, Text, Union
from urllib.parse import urlparse
import numpy as np
@ -10,6 +10,7 @@ import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch import nn
from torch.optim import Adam
from enhancer.data.dataset import EnhancerDataset
@ -49,7 +50,7 @@ class Model(pl.LightningModule):
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
metric: Union[str, List, Any] = "mse",
):
super().__init__()
assert (
@ -86,10 +87,11 @@ class Model(pl.LightningModule):
@metric.setter
def metric(self, metric):
self._metric = []
if isinstance(metric, str):
if isinstance(metric, (str, nn.Module)):
metric = [metric]
for func in metric:
if isinstance(func, str):
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
@ -97,9 +99,13 @@ class Model(pl.LightningModule):
)
else:
self._metric.append(LOSS_MAP[func]())
else:
raise ValueError(f"Invalid metrics {func}")
ValueError(f"Invalid metrics {func}")
elif isinstance(func, nn.Module):
self._metric.append(func)
else:
raise ValueError("Invalid metrics")
@property
def dataset(self):