Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-13 10:50:59 +05:30
commit 94a4ea38ed
2 changed files with 20 additions and 9 deletions

View File

@ -1,4 +1,7 @@
# mayavoz <p align="center">
<img src="https://user-images.githubusercontent.com/25312635/195507951-fe64657c-9114-4d78-b04e-444e6d5bbcc4.png" />
</p>
mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training . mayavoz is a Pytorch-based opensource toolkit for speech enhancement. It is designed to save time for audio researchers. Is provides easy to use pretrained audio enhancement models and facilitates highly customisable custom model training .
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()** | **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**

View File

@ -1,5 +1,6 @@
import logging import logging
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
@ -116,17 +117,24 @@ class Stoi:
class Pesq: class Pesq:
def __init__(self, sr: int, mode="nb"): def __init__(self, sr: int, mode="wb"):
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode) self.sr = sr
self.name = "pesq" self.name = "pesq"
self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor): def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
try:
return self.pesq(prediction, target) pesq_values = []
except Exception as e: for pred, target_ in zip(prediction, target):
logging.warning(f"{e} error occured while calculating PESQ") try:
return torch.tensor(0.0) pesq_values.append(
self.pesq(pred.squeeze(), target_.squeeze()).item()
)
except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values))
class LossWrapper(nn.Module): class LossWrapper(nn.Module):
@ -177,7 +185,7 @@ class LossWrapper(nn.Module):
LOSS_MAP = { LOSS_MAP = {
"mae": mean_absolute_error, "mae": mean_absolute_error,
"mse": mean_squared_error, "mse": mean_squared_error,
"SI-SDR": Si_SDR, "si-sdr": Si_SDR,
"pesq": Pesq, "pesq": Pesq,
"stoi": Stoi, "stoi": Stoi,
} }