Merge pull request #19 from shahules786/dev-loss

Support custom loss functions
Shahul ES 2022-11-03 09:53:25 +05:30 committed by GitHub
commit a082474034
2 changed files with 17 additions and 12 deletions
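
For context, a minimal usage sketch of what this change is meant to allow: a user-defined nn.Module passed as the loss or metric instead of one of the registered string names. The WeightedL1Loss class and the commented-out model construction are illustrative assumptions; only the loss=/metric= keyword arguments and the nn.Module support come from the diff below.

# Illustrative only: a custom loss written as a plain nn.Module. The class name
# and the commented-out constructor call are assumptions, not code from this commit.
import torch
from torch import nn


class WeightedL1Loss(nn.Module):
    """Toy custom loss: scaled mean absolute error between prediction and target."""

    def __init__(self, scale: float = 2.0):
        super().__init__()
        self.scale = scale

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.scale * torch.mean(torch.abs(prediction - target))


# With this commit, a concrete enhancer model could accept the module directly,
# e.g. (model class and other constructor arguments omitted/assumed):
# model = SomeEnhancerModel(loss=WeightedL1Loss(), metric=["mse", WeightedL1Loss()])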


@@ -350,7 +350,6 @@ class EnhancerDataset(TaskDataset):
                         self.duration * self.sampling_rate - clean_segment.shape[-1]
                     ),
                 ),
-                mode=self.padding_mode,
             )
             noisy_segment = F.pad(
                 noisy_segment,


@@ -2,7 +2,7 @@ import os
 from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import List, Optional, Text, Union
+from typing import Any, List, Optional, Text, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -10,6 +10,7 @@ import pytorch_lightning as pl
 import torch
 from huggingface_hub import cached_download, hf_hub_url
 from pytorch_lightning.utilities.cloud_io import load as pl_load
+from torch import nn
 from torch.optim import Adam

 from enhancer.data.dataset import EnhancerDataset
@@ -36,7 +37,7 @@ class Model(pl.LightningModule):
         Enhancer dataset used for training/validation
     duration: float, optional
         duration used for training/inference
-    loss : string or List of strings, default to "mse"
+    loss : string or List of strings or custom loss (nn.Module), default to "mse"
        loss functions to be used. Available ("mse","mae","Si-SDR")
     """
@@ -49,7 +50,7 @@ class Model(pl.LightningModule):
         dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric: Union[str, List] = "mse",
+        metric: Union[str, List, Any] = "mse",
     ):
         super().__init__()
         assert (
@@ -86,10 +87,11 @@
     @metric.setter
     def metric(self, metric):
         self._metric = []
-        if isinstance(metric, str):
+        if isinstance(metric, (str, nn.Module)):
             metric = [metric]

         for func in metric:
+            if isinstance(func, str):
                 if func in LOSS_MAP.keys():
                     if func in ("pesq", "stoi"):
                         self._metric.append(
@@ -97,9 +99,13 @@
                         )
                     else:
                         self._metric.append(LOSS_MAP[func]())
                 else:
-                    raise ValueError(f"Invalid metrics {func}")
+                    ValueError(f"Invalid metrics {func}")
+            elif isinstance(func, nn.Module):
+                self._metric.append(func)
+            else:
+                raise ValueError("Invalid metrics")

     @property
     def dataset(self):
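
Taken together, the new metric setter dispatches on element type: registered string names are looked up in LOSS_MAP (with a special case for "pesq"/"stoi"), nn.Module instances are used as-is, and anything else is rejected. A simplified, standalone sketch of that dispatch, with a stand-in registry in place of the library's real LOSS_MAP and pesq/stoi handling and a hypothetical helper name, might look like this:

from typing import Any, List, Union

from torch import nn

# Stand-in registry; the real LOSS_MAP (and the pesq/stoi special case) lives
# in the enhancer loss module and is only approximated here.
LOSS_MAP = {"mse": nn.MSELoss, "mae": nn.L1Loss}


def resolve_metrics(metric: Union[str, List, Any]) -> List[nn.Module]:
    """Normalise `metric` the way the new setter does: a single value becomes a
    list, strings are resolved through the registry, nn.Module instances are
    kept as-is, and unsupported types are rejected."""
    if isinstance(metric, (str, nn.Module)):
        metric = [metric]

    resolved = []
    for func in metric:
        if isinstance(func, str):
            if func in LOSS_MAP:
                resolved.append(LOSS_MAP[func]())
            else:
                raise ValueError(f"Invalid metrics {func}")
        elif isinstance(func, nn.Module):
            resolved.append(func)
        else:
            raise ValueError("Invalid metrics")
    return resolved


# Example: mixing a registered name with a custom module.
print(resolve_metrics(["mse", nn.SmoothL1Loss()]))

Normalising everything into a list of nn.Module instances up front keeps the rest of the training/validation code uniform, regardless of whether metrics were given as strings, modules, or a mixed list.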