Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in:
commit
94a4ea38ed
|
|
@ -1,4 +1,7 @@
|
||||||
# mayavoz
|
<p align="center">
|
||||||
|
<img src="https://user-images.githubusercontent.com/25312635/195507951-fe64657c-9114-4d78-b04e-444e6d5bbcc4.png" />
|
||||||
|
</p>
|
||||||
|
|
||||||
mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.
|
mayavoz is a PyTorch-based open-source toolkit for speech enhancement. It is designed to save time for audio researchers. It provides easy-to-use pretrained audio enhancement models and facilitates highly customisable model training.
|
||||||
|
|
||||||
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
|
| **[Quick Start]()** | **[Installation]()** | **[Tutorials]()** | **[Available Recipes]()**
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
||||||
|
|
@ -116,17 +117,24 @@ class Stoi:
|
||||||
|
|
||||||
|
|
||||||
class Pesq:
|
class Pesq:
|
||||||
def __init__(self, sr: int, mode="nb"):
|
def __init__(self, sr: int, mode="wb"):
|
||||||
|
|
||||||
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
|
self.sr = sr
|
||||||
self.name = "pesq"
|
self.name = "pesq"
|
||||||
|
self.mode = mode
|
||||||
|
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
|
||||||
|
|
||||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||||
|
|
||||||
|
pesq_values = []
|
||||||
|
for pred, target_ in zip(prediction, target):
|
||||||
try:
|
try:
|
||||||
return self.pesq(prediction, target)
|
pesq_values.append(
|
||||||
|
self.pesq(pred.squeeze(), target_.squeeze()).item()
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"{e} error occured while calculating PESQ")
|
logging.warning(f"{e} error occured while calculating PESQ")
|
||||||
return torch.tensor(0.0)
|
return torch.tensor(np.mean(pesq_values))
|
||||||
|
|
||||||
|
|
||||||
class LossWrapper(nn.Module):
|
class LossWrapper(nn.Module):
|
||||||
|
|
@ -177,7 +185,7 @@ class LossWrapper(nn.Module):
|
||||||
LOSS_MAP = {
|
LOSS_MAP = {
|
||||||
"mae": mean_absolute_error,
|
"mae": mean_absolute_error,
|
||||||
"mse": mean_squared_error,
|
"mse": mean_squared_error,
|
||||||
"SI-SDR": Si_SDR,
|
"si-sdr": Si_SDR,
|
||||||
"pesq": Pesq,
|
"pesq": Pesq,
|
||||||
"stoi": Stoi,
|
"stoi": Stoi,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue