add support for custom loss
This commit is contained in:
parent
7f3dcf39c5
commit
2f85f48d69
|
|
@ -2,7 +2,7 @@ import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Text, Union
|
from typing import Any, List, Optional, Text, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -10,6 +10,7 @@ import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import cached_download, hf_hub_url
|
from huggingface_hub import cached_download, hf_hub_url
|
||||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||||
|
from torch import nn
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from enhancer.data.dataset import EnhancerDataset
|
||||||
|
|
@ -49,7 +50,7 @@ class Model(pl.LightningModule):
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
duration: Optional[float] = None,
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List, Any] = "mse",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -86,20 +87,25 @@ class Model(pl.LightningModule):
|
||||||
@metric.setter
|
@metric.setter
|
||||||
def metric(self, metric):
|
def metric(self, metric):
|
||||||
self._metric = []
|
self._metric = []
|
||||||
if isinstance(metric, str):
|
if isinstance(metric, (str, nn.Module)):
|
||||||
metric = [metric]
|
metric = [metric]
|
||||||
|
|
||||||
for func in metric:
|
for func in metric:
|
||||||
if func in LOSS_MAP.keys():
|
if isinstance(func, str):
|
||||||
if func in ("pesq", "stoi"):
|
if func in LOSS_MAP.keys():
|
||||||
self._metric.append(
|
if func in ("pesq", "stoi"):
|
||||||
LOSS_MAP[func](self.hparams.sampling_rate)
|
self._metric.append(
|
||||||
)
|
LOSS_MAP[func](self.hparams.sampling_rate)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._metric.append(LOSS_MAP[func]())
|
||||||
else:
|
else:
|
||||||
self._metric.append(LOSS_MAP[func]())
|
ValueError(f"Invalid metrics {func}")
|
||||||
|
|
||||||
|
elif isinstance(func, nn.Module):
|
||||||
|
self._metric.append(func)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid metrics {func}")
|
raise ValueError("Invalid metrics")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dataset(self):
|
def dataset(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue