Merge pull request #19 from shahules786/dev-loss

Support custom loss functions
Shahul ES 2022-11-03 09:53:25 +05:30 committed by GitHub
commit a082474034
2 changed files with 17 additions and 12 deletions
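In user-facing terms, this change lets a custom torch nn.Module be passed wherever a built-in loss/metric name was accepted before. A minimal usage sketch, assuming a concrete Model subclass from this repo (shown here as Demucs, a stand-in) and an already-built EnhancerDataset called dataset; the ScaledMAELoss module is hypothetical, and mixing names with modules in the loss list is inferred from the updated docstring and the parallel metric setter shown further down:

import torch
from torch import nn

class ScaledMAELoss(nn.Module):
    # Hypothetical custom loss: mean absolute error scaled by a constant factor.
    def __init__(self, scale: float = 0.5):
        super().__init__()
        self.scale = scale

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.scale * torch.mean(torch.abs(prediction - target))

# After this commit, built-in names and nn.Module instances can be combined.
model = Demucs(
    dataset=dataset,
    loss=["mae", ScaledMAELoss()],     # string, list of strings, or custom nn.Module
    metric=["stoi", ScaledMAELoss()],  # metric setter now accepts nn.Module entries too
)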

View File

@@ -350,7 +350,6 @@ class EnhancerDataset(TaskDataset):
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
mode=self.padding_mode,
)
noisy_segment = F.pad(
noisy_segment,
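For context, the pad calls in this hunk stretch each segment to duration * sampling_rate samples along the last dimension. A minimal, self-contained illustration of torch.nn.functional.pad used that way (the tensor sizes are made-up numbers, not values from the repo):

import torch
import torch.nn.functional as F

clean_segment = torch.randn(1, 12000)  # hypothetical 12 000-sample segment
target_len = 16000                     # e.g. duration * sampling_rate
clean_segment = F.pad(
    clean_segment,
    pad=(0, target_len - clean_segment.shape[-1]),  # right-pad the last dimension only
)
print(clean_segment.shape)             # torch.Size([1, 16000])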

View File

@@ -2,7 +2,7 @@ import os
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import List, Optional, Text, Union
from typing import Any, List, Optional, Text, Union
from urllib.parse import urlparse
import numpy as np
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
import torch
from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch import nn
from torch.optim import Adam
from enhancer.data.dataset import EnhancerDataset
@@ -36,7 +37,7 @@ class Model(pl.LightningModule):
Enhancer dataset used for training/validation
duration: float, optional
duration used for training/inference
loss : string or List of strings, default to "mse"
loss : string or List of strings or custom loss (nn.Module), default to "mse"
loss functions to be used. Available ("mse","mae","Si-SDR")
"""
@@ -49,7 +50,7 @@
dataset: Optional[EnhancerDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
metric: Union[str, List, Any] = "mse",
):
super().__init__()
assert (
@@ -86,20 +87,25 @@
@metric.setter
def metric(self, metric):
self._metric = []
if isinstance(metric, str):
if isinstance(metric, (str, nn.Module)):
metric = [metric]
for func in metric:
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
LOSS_MAP[func](self.hparams.sampling_rate)
)
if isinstance(func, str):
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
LOSS_MAP[func](self.hparams.sampling_rate)
)
else:
self._metric.append(LOSS_MAP[func]())
else:
self._metric.append(LOSS_MAP[func]())
ValueError(f"Invalid metrics {func}")
elif isinstance(func, nn.Module):
self._metric.append(func)
else:
raise ValueError(f"Invalid metrics {func}")
raise ValueError("Invalid metrics")
@property
def dataset(self):
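Because the removed and added lines are shown interleaved above, the resulting metric setter is easier to read in one piece. A sketch of the post-commit logic as reconstructed from the diff (indentation and the enclosing class are assumed; the diff shows ValueError(f"Invalid metrics {func}") without an explicit raise in the unknown-name branch, and this sketch adds the raise on the assumption that dropping it was unintended):

@metric.setter
def metric(self, metric):
    self._metric = []
    # A single name or a single nn.Module gets wrapped into a list.
    if isinstance(metric, (str, nn.Module)):
        metric = [metric]
    for func in metric:
        if isinstance(func, str):
            if func in LOSS_MAP.keys():
                # pesq/stoi need the sampling rate at construction time.
                if func in ("pesq", "stoi"):
                    self._metric.append(
                        LOSS_MAP[func](self.hparams.sampling_rate)
                    )
                else:
                    self._metric.append(LOSS_MAP[func]())
            else:
                raise ValueError(f"Invalid metrics {func}")
        elif isinstance(func, nn.Module):
            # Custom metrics are used as-is.
            self._metric.append(func)
        else:
            raise ValueError("Invalid metrics")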