Merge pull request #17 from shahules786/dev-datafix
foolproof iteration
commit a1445b0a95
@@ -4,10 +4,16 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import (
+    EarlyStopping,
+    LearningRateMonitor,
+    ModelCheckpoint,
+)
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau

+# from torch_audiomentations import Compose, Shift

 os.environ["HYDRA_FULL_ERROR"] = "1"
 JOB_ID = os.environ.get("SLURM_JOBID", "0")

@@ -25,8 +31,13 @@ def main(config: DictConfig):
     )

     parameters = config.hyperparameters
+    # apply_augmentations = Compose(
+    #     [
+    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
+    #     ]
+    # )

-    dataset = instantiate(config.dataset)
+    dataset = instantiate(config.dataset, augmentations=None)
     model = instantiate(
         config.model,
         dataset=dataset,

@@ -45,6 +56,8 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
+    callbacks.append(LearningRateMonitor(logging_interval="epoch"))
+
     if parameters.get("Early_stop", False):
         early_stopping = EarlyStopping(
             monitor="val_loss",

@@ -56,11 +69,11 @@ def main(config: DictConfig):
         )
         callbacks.append(early_stopping)

-    def configure_optimizer(self):
+    def configure_optimizers(self):
         optimizer = instantiate(
             config.optimizer,
             lr=parameters.get("lr"),
-            parameters=self.parameters(),
+            params=self.parameters(),
         )
         scheduler = ReduceLROnPlateau(
             optimizer=optimizer,
@@ -70,9 +83,13 @@ def main(config: DictConfig):
             min_lr=parameters.get("min_lr", 1e-6),
             patience=parameters.get("ReduceLr_patience", 3),
         )
-        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
+        }

-    model.configure_parameters = MethodType(configure_optimizer, model)
+    model.configure_optimizers = MethodType(configure_optimizers, model)

     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
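Note: the hunk above builds the optimizer/scheduler pair inside the training script, returns a "monitor" key so ReduceLROnPlateau follows the logged valid_* metric, and binds the function onto the instantiated model with types.MethodType. A minimal standalone sketch of that binding pattern; TinyModel and the concrete Adam optimizer below are illustrative stand-ins for the Hydra-instantiated objects, not this repo's code:

from types import MethodType

import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


class TinyModel(pl.LightningModule):
    # Illustrative stand-in for the Hydra-instantiated enhancement model.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)


def configure_optimizers(self):
    # `self` is whatever instance the function gets bound to below.
    optimizer = torch.optim.Adam(self.parameters(), lr=3e-4)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.2, patience=5)
    # Lightning steps the scheduler against the logged "valid_loss" value.
    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
        "monitor": "valid_loss",
    }


model = TinyModel()
# Bind the free function as an instance method, the same trick the diff uses.
model.configure_optimizers = MethodType(configure_optimizers, model)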
@@ -2,10 +2,10 @@ _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
-stride : 0.5
+stride : 2
 sampling_rate: 16000
-batch_size: 4
-valid_minutes : 1
+batch_size: 32
+valid_minutes : 15
 files:
   train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav

@@ -1,7 +1,7 @@
-loss : mse
-metric : mae
-lr : 0.0001
+loss : mae
+metric : [stoi,pesq,si-sdr]
+lr : 0.0003
 ReduceLr_patience : 5
-ReduceLr_factor : 0.1
+ReduceLr_factor : 0.2
 min_lr : 0.000001
 EarlyStopping_factor : 10

@@ -1,2 +1,2 @@
 experiment_name : shahules/enhancer
-run_name : baseline
+run_name : Demucs + Vtck with stride + augmentations

@@ -1,11 +1,11 @@
 _target_: enhancer.models.demucs.Demucs
 num_channels: 1
-resample: 2
+resample: 4
 sampling_rate : 16000

 encoder_decoder:
-  depth: 5
-  initial_output_channels: 32
+  depth: 4
+  initial_output_channels: 64
   kernel_size: 8
   stride: 4
   growth_factor: 2

@@ -1,5 +1,5 @@
 _target_: enhancer.models.waveunet.WaveUnet
 num_channels : 1
-depth : 12
+depth : 9
 initial_output_channels: 24
 sampling_rate : 16000
@@ -1,5 +1,5 @@
 _target_: pytorch_lightning.Trainer
-accelerator: auto
+accelerator: gpu
 accumulate_grad_batches: 1
 amp_backend: native
 auto_lr_find: True

@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: 1
+devices: 2
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True

@@ -23,7 +23,7 @@ limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
 log_every_n_steps: 50
-max_epochs: 3
+max_epochs: 200
 max_steps: -1
 max_time: null
 min_epochs: 1

@@ -38,7 +38,7 @@ precision: 32
 profiler: null
 reload_dataloaders_every_n_epochs: 0
 replace_sampler_ddp: True
-strategy: null
+strategy: ddp
 sync_batchnorm: False
 tpu_cores: null
 track_grad_norm: -1
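Note: the trainer config now targets two GPUs with DDP and trains for up to 200 epochs. Roughly, Hydra's instantiate ends up building a call like the sketch below; only a handful of the YAML keys are shown, and the exact argument set depends on the pytorch_lightning version the repo pins:

import pytorch_lightning as pl

# Mirrors accelerator/devices/strategy/max_epochs from the updated trainer YAML.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    max_epochs=200,
    accumulate_grad_batches=1,
    precision=32,
)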
@@ -1,14 +1,15 @@
 import math
 import multiprocessing
 import os
-import random
-from itertools import chain, cycle
+from pathlib import Path
 from typing import Optional

+import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset
+from torch_audiomentations import Compose

 from enhancer.data.fileprocessor import Fileprocessor
 from enhancer.utils import check_files
@@ -16,13 +17,15 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng

+LARGE_NUM = 2147483647
+

-class TrainDataset(IterableDataset):
+class TrainDataset(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset

-    def __iter__(self):
-        return self.dataset.train__iter__()
+    def __getitem__(self, idx):
+        return self.dataset.train__getitem__(idx)

     def __len__(self):
         return self.dataset.train__len__()
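Note: TrainDataset switches from an IterableDataset driven by train__iter__ to a map-style Dataset indexed through train__getitem__, so the DataLoader can do its own batching and shuffling. A toy sketch of the delegation pattern; ToyBackend is a made-up stand-in for the real datamodule:

import torch
from torch.utils.data import DataLoader, Dataset


class ToyBackend:
    # Stand-in for the datamodule that actually owns the samples.
    def __init__(self, n=10):
        self.items = [torch.full((1, 4), float(i)) for i in range(n)]

    def train__getitem__(self, idx):
        return self.items[idx]

    def train__len__(self):
        return len(self.items)


class MapStyleTrainDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.train__getitem__(idx)

    def __len__(self):
        return self.dataset.train__len__()


# The loader now controls batch size and shuffling; the old iterable version
# had to emit ready-made batches itself.
loader = DataLoader(MapStyleTrainDataset(ToyBackend()), batch_size=4, shuffle=True)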
@@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        augmentations: Optional[Compose] = None,
     ):
         super().__init__()

@@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
         else:
             raise ValueError("valid_minutes must be greater than 0")

+        self.augmentations = augmentations
+
     def setup(self, stage: Optional[str] = None):
         """
         prepare train/validation/test data splits
@@ -115,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
     ):

         valid_minutes *= 60
-        valid_min_now = 0.0
+        valid_sec_now = 0.0
         valid_indices = []
-        random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state)
-        rng.shuffle(random_indices)
-        i = 0
-        while valid_min_now <= valid_minutes:
-            valid_indices.append(random_indices[i])
-            valid_min_now += data[random_indices[i]]["duration"]
-            i += 1
+        all_speakers = np.unique(
+            [
+                (Path(file["clean"]).name.split("_")[0], file["duration"])
+                for file in data
+            ]
+        )
+        possible_indices = list(range(0, len(all_speakers)))
+        rng = create_unique_rng(len(all_speakers))
+
+        while valid_sec_now <= valid_minutes:
+            speaker_index = rng.choice(possible_indices)
+            possible_indices.remove(speaker_index)
+            speaker_name = all_speakers[speaker_index]
+            file_indices = [
+                i
+                for i, file in enumerate(data)
+                if speaker_name == Path(file["clean"]).name.split("_")[0]
+            ]
+            for i in file_indices:
+                valid_indices.append(i)
+                valid_sec_now += data[i]["duration"]

         train_data = [
             item for i, item in enumerate(data) if i not in valid_indices
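Note: validation is now carved out speaker-by-speaker (VCTK files are named like p225_001.wav, so the prefix before the first underscore identifies the speaker) until the requested number of validation seconds is reached, rather than by shuffling individual files. A standalone sketch of that selection loop; the three metadata entries are invented for illustration, and plain random.Random stands in for create_unique_rng:

import random
from pathlib import Path

data = [  # invented metadata in the same shape the datamodule uses
    {"clean": "p225_001.wav", "noisy": "p225_001.wav", "duration": 3.0},
    {"clean": "p225_002.wav", "noisy": "p225_002.wav", "duration": 2.5},
    {"clean": "p226_001.wav", "noisy": "p226_001.wav", "duration": 4.0},
]
valid_seconds_wanted = 5.0

speakers = sorted({Path(f["clean"]).name.split("_")[0] for f in data})
rng = random.Random(len(speakers))
valid_sec, valid_indices = 0.0, []
while valid_sec <= valid_seconds_wanted and speakers:
    # Pull whole speakers into validation so no speaker leaks into training.
    speaker = speakers.pop(rng.randrange(len(speakers)))
    for i, f in enumerate(data):
        if Path(f["clean"]).name.split("_")[0] == speaker:
            valid_indices.append(i)
            valid_sec += f["duration"]

train_data = [f for i, f in enumerate(data) if i not in valid_indices]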
@@ -135,16 +154,11 @@ class TaskDataset(pl.LightningDataModule):
     def prepare_traindata(self, data):
         train_data = []
         for item in data:
-            samples_metadata = []
             clean, noisy, total_dur = item.values()
             num_segments = self.get_num_segments(
                 total_dur, self.duration, self.stride
             )
-            for index in range(num_segments):
-                start = index * self.stride
-                samples_metadata.append(
-                    ({"clean": clean, "noisy": noisy}, start)
-                )
+            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
             train_data.append(samples_metadata)
         return train_data

@@ -166,7 +180,9 @@ class TaskDataset(pl.LightningDataModule):
         if total_dur < self.duration:
             metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
         else:
-            num_segments = round(total_dur / self.duration)
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.duration
+            )
             for index in range(num_segments):
                 start_time = index * self.duration
                 metadata.append(
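Note: the train metadata now stores one (file, num_segments) pair per file, and both train and validation metadata go through get_num_segments. With the formula from the helper whose old copy is removed further down in this diff, ceil((file_duration - duration) / stride) + 1, a 10 s file chunked with the new duration=4.5 and stride=2 settings yields ceil((10 - 4.5) / 2) + 1 = 4 segments. A quick check, assuming that same formula:

import math


def get_num_segments(file_duration, duration, stride):
    # One segment when the file is shorter than the window, otherwise slide by `stride`.
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1


assert get_num_segments(10.0, 4.5, 2.0) == 4
assert get_num_segments(3.0, 4.5, 2.0) == 1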
@@ -175,31 +191,44 @@
         return metadata

     def train_collatefn(self, batch):
-        output = {"noisy": [], "clean": []}
+        output = {"clean": [], "noisy": []}
         for item in batch:
-            output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
+            output["noisy"].append(item["noisy"])

         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)

+        if self.augmentations is not None:
+            noise = output["noisy"] - output["clean"]
+            output["clean"] = self.augmentations(
+                output["clean"], sample_rate=self.sampling_rate
+            )
+            self.augmentations.freeze_parameters()
+            output["noisy"] = (
+                self.augmentations(noise, sample_rate=self.sampling_rate)
+                + output["clean"]
+            )
+
         return output

-    def worker_init_fn(self, _):
-        worker_info = torch.utils.data.get_worker_info()
-        dataset = worker_info.dataset
-        worker_id = worker_info.id
-        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
-        dataset.data = dataset.dataset.train_data[
-            worker_id * split_size : (worker_id + 1) * split_size
-        ]
+    @property
+    def generator(self):
+        generator = torch.Generator()
+        if hasattr(self, "model"):
+            seed = self.model.current_epoch + LARGE_NUM
+        else:
+            seed = LARGE_NUM
+        return generator.manual_seed(seed)

     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=None,
+            batch_size=self.batch_size,
             num_workers=self.num_workers,
+            generator=self.generator,
             collate_fn=self.train_collatefn,
-            worker_init_fn=self.worker_init_fn,
         )

     def val_dataloader(self):
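Note: augmentations are now applied batch-wise in the collate function: the clean batch is transformed, freeze_parameters() pins the sampled random parameters, and the identical transform is replayed on the noise before adding it back onto the augmented clean signal, so clean/noisy stay aligned. The generator property likewise reseeds the DataLoader from model.current_epoch + LARGE_NUM, presumably so shuffling changes per epoch but stays reproducible. A self-contained sketch of the augmentation replay, using a Gain transform as a stand-in for whatever Compose the config supplies:

import torch
from torch_audiomentations import Compose, Gain

augment = Compose([Gain(min_gain_in_db=-6.0, max_gain_in_db=6.0, p=1.0)])

clean = torch.randn(8, 1, 16000)                 # (batch, channels, samples)
noisy = clean + 0.1 * torch.randn_like(clean)

noise = noisy - clean
clean_aug = augment(clean, sample_rate=16000)
augment.freeze_parameters()                      # replay the same random gain
noisy_aug = augment(noise, sample_rate=16000) + clean_aug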
@@ -251,11 +280,12 @@ class EnhancerDataset(TaskDataset):
         files: Files,
         valid_minutes=5.0,
         duration=1.0,
-        stride=0.5,
+        stride=None,
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        augmentations: Optional[Compose] = None,
     ):

         super().__init__(

@@ -268,6 +298,7 @@ class EnhancerDataset(TaskDataset):
             matching_function=matching_function,
             batch_size=batch_size,
             num_workers=num_workers,
+            augmentations=augmentations,
         )

         self.sampling_rate = sampling_rate
@@ -280,35 +311,17 @@

         super().setup(stage=stage)

-    def random_sample(self, train_data):
-        return random.sample(train_data, len(train_data))
-
-    def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-        train_data = rng.sample(self.train_data, len(self.train_data))
-        return zip(
-            *[
-                self.get_stream(self.random_sample(train_data))
-                for i in range(self.batch_size)
-            ]
-        )
-
-    def get_stream(self, data):
-        return chain.from_iterable(map(self.process_data, cycle(data)))
-
-    def process_data(self, data):
-        for item in data:
-            yield self.prepare_segment(*item)
-
-    @staticmethod
-    def get_num_segments(file_duration, duration, stride):
-
-        if file_duration < duration:
-            num_segments = 1
-        else:
-            num_segments = math.ceil((file_duration - duration) / stride) + 1
-
-        return num_segments
+    def train__getitem__(self, idx):
+        for filedict, num_samples in self.train_data:
+            if idx >= num_samples:
+                idx -= num_samples
+                continue
+            else:
+                start = 0
+                if self.duration is not None:
+                    start = idx * self.stride
+                return self.prepare_segment(filedict, start)

     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
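Note: train__getitem__ resolves a flat sample index by walking the stored (file, num_segments) pairs, subtracting each file's segment count until the index falls inside a file, then turning the remaining offset into a start time via the stride. A compact sketch of that mapping with placeholder file names:

segments_per_file = [("a.wav", 4), ("b.wav", 2), ("c.wav", 3)]  # (file, num_segments)
stride = 2.0


def locate(idx):
    # Map a flat index to (file, start_time_in_seconds).
    for name, num_segments in segments_per_file:
        if idx >= num_segments:
            idx -= num_segments
            continue
        return name, idx * stride
    raise IndexError("index out of range")


assert locate(0) == ("a.wav", 0.0)
assert locate(5) == ("b.wav", 2.0)
assert locate(8) == ("c.wav", 4.0)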
@@ -348,7 +361,8 @@
         }

     def train__len__(self):
-        return sum([len(item) for item in self.train_data]) // (self.batch_size)
+        _, num_examples = list(zip(*self.train_data))
+        return sum(num_examples)

     def val__len__(self):
         return len(self._validation)
@@ -3,7 +3,7 @@ import logging
 import numpy as np
 import torch
 import torch.nn as nn
-from pesq import pesq
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility


@@ -66,7 +66,7 @@ class Si_SDR:
                 "Invalid reduction, valid options are sum, mean, None"
             )
         self.higher_better = False
-        self.name = "Si-SDR"
+        self.name = "si-sdr"

     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

@@ -122,20 +122,16 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
+        self.pesq = PerceptualEvaluationSpeechQuality(
+            fs=self.sr, mode=self.mode
+        )

     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

         pesq_values = []
         for pred, target_ in zip(prediction, target):
             try:
-                pesq_values.append(
-                    pesq(
-                        self.sr,
-                        target_.squeeze().detach().cpu().numpy(),
-                        pred.squeeze().detach().cpu().numpy(),
-                        self.mode,
-                    )
-                )
+                pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
             except Exception as e:
                 logging.warning(f"{e} error occured while calculating PESQ")
         return torch.tensor(np.mean(pesq_values))
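Note: PESQ is now computed through torchmetrics' PerceptualEvaluationSpeechQuality wrapper on tensors directly (it still needs the pesq package, hence the pinned pesq==0.0.4 further down in requirements) instead of calling pesq() on detached NumPy arrays. A minimal sketch of the wrapped metric on 16 kHz wide-band tensors; random tensors stand in for real speech here:

import torch
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

pesq_metric = PerceptualEvaluationSpeechQuality(fs=16000, mode="wb")

target = torch.randn(16000)                      # one second of "reference" audio
prediction = target + 0.05 * torch.randn(16000)
score = pesq_metric(prediction, target)          # tensor holding the PESQ MOS-LQO score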
@@ -113,6 +113,8 @@ class Model(pl.LightningModule):
         if stage == "fit":
             torch.cuda.empty_cache()
             self.dataset.setup(stage)
+
+            self.dataset.model = self

             print(
                 "Total train duration",

@@ -134,7 +136,6 @@
             / 60,
             "minutes",
         )
-        self.dataset.model = self

     def train_dataloader(self):
         return self.dataset.train_dataloader()
@@ -70,7 +70,7 @@ class Audio:

         if sampling_rate:
             audio = self.__class__.resample_audio(
-                audio, self.sampling_rate, sampling_rate
+                audio, sampling_rate, self.sampling_rate
             )
         if self.return_tensor:
             return torch.tensor(audio)
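Note: the fix above swaps the last two arguments so the incoming array is resampled from the caller-supplied rate to the Audio object's own sampling_rate rather than the reverse. The intended direction, expressed with librosa purely for illustration; the (audio, source_rate, target_rate) ordering of resample_audio is an assumption here, not its documented signature:

import numpy as np
import librosa

file_rate, target_rate = 48000, 16000
audio = np.random.randn(file_rate)  # one second of audio as read from disk at 48 kHz

# Resample FROM the rate the file was read at TO the rate the model expects.
audio_16k = librosa.resample(audio, orig_sr=file_rate, target_sr=target_rate)
assert audio_16k.shape[-1] == target_rate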
@@ -5,7 +5,7 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
-git+https://github.com/ludlows/python-pesq#egg=pesq
+pesq==0.0.4
 protobuf>=3.19.6
 pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3

@@ -14,5 +14,6 @@ scikit-learn>=1.1.2
 scipy>=1.9.1
 soundfile>=0.11.0
 torch>=1.12.1
+torch-audiomentations==0.11.0
 torchaudio>=0.12.1
 tqdm>=4.64.1