Merge pull request #19 from shahules786/dev-loss
Support custom loss functions
This commit is contained in:
commit
a082474034
|
|
@ -350,7 +350,6 @@ class EnhancerDataset(TaskDataset):
|
|||
self.duration * self.sampling_rate - clean_segment.shape[-1]
|
||||
),
|
||||
),
|
||||
mode=self.padding_mode,
|
||||
)
|
||||
noisy_segment = F.pad(
|
||||
noisy_segment,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
from collections import defaultdict
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Text, Union
|
||||
from typing import Any, List, Optional, Text, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -10,6 +10,7 @@ import pytorch_lightning as pl
|
|||
import torch
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from torch import nn
|
||||
from torch.optim import Adam
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
|
|
@ -36,7 +37,7 @@ class Model(pl.LightningModule):
|
|||
Enhancer dataset used for training/validation
|
||||
duration: float, optional
|
||||
duration used for training/inference
|
||||
loss : string or List of strings, default to "mse"
|
||||
loss : string or List of strings or custom loss (nn.Module), default to "mse"
|
||||
loss functions to be used. Available ("mse","mae","Si-SDR")
|
||||
|
||||
"""
|
||||
|
|
@ -49,7 +50,7 @@ class Model(pl.LightningModule):
|
|||
dataset: Optional[EnhancerDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
metric: Union[str, List, Any] = "mse",
|
||||
):
|
||||
super().__init__()
|
||||
assert (
|
||||
|
|
@ -86,10 +87,11 @@ class Model(pl.LightningModule):
|
|||
@metric.setter
|
||||
def metric(self, metric):
|
||||
self._metric = []
|
||||
if isinstance(metric, str):
|
||||
if isinstance(metric, (str, nn.Module)):
|
||||
metric = [metric]
|
||||
|
||||
for func in metric:
|
||||
if isinstance(func, str):
|
||||
if func in LOSS_MAP.keys():
|
||||
if func in ("pesq", "stoi"):
|
||||
self._metric.append(
|
||||
|
|
@ -97,9 +99,13 @@ class Model(pl.LightningModule):
|
|||
)
|
||||
else:
|
||||
self._metric.append(LOSS_MAP[func]())
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid metrics {func}")
|
||||
ValueError(f"Invalid metrics {func}")
|
||||
|
||||
elif isinstance(func, nn.Module):
|
||||
self._metric.append(func)
|
||||
else:
|
||||
raise ValueError("Invalid metrics")
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue