merge dev

commit 4b34cf6980

@@ -1,5 +1,6 @@
 #local
 *.ckpt
+*_local.yaml
 cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/

@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: -1
+devices: 1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True

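For context: in PyTorch Lightning, `devices=-1` means "use every available device", while `devices=1` pins the run to a single device. A minimal sketch of the equivalent Trainer call (the `accelerator` choice here is an assumption, not taken from this config):

import pytorch_lightning as pl

# devices=-1 claims all visible GPUs (and implies a distributed strategy);
# devices=1 keeps training on a single device, matching the change above.
trainer = pl.Trainer(accelerator="gpu", devices=1)

A single device also sidesteps distributed-sampler interactions with the custom IterableDataset streaming introduced below, which may be the motivation here.
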
@@ -1,9 +1,12 @@
 import math
 import multiprocessing
 import os
+import random
+from itertools import chain, cycle
 from typing import Optional

 import pytorch_lightning as pl
+import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, IterableDataset

@@ -55,6 +58,7 @@ class TaskDataset(pl.LightningDataModule):
         files: Files,
         valid_minutes: float = 0.20,
         duration: float = 1.0,
+        stride=None,
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,

@@ -65,6 +69,7 @@ class TaskDataset(pl.LightningDataModule):
         self.name = name
         self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
+        self.stride = stride or duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function

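`stride or duration` defaults the hop size to the segment length, i.e. non-overlapping windows when no stride is given; note that an explicit `stride=0` is also falsy and would silently fall back to `duration`. A toy illustration (values are hypothetical):

duration = 1.0
for stride in (None, 0.5):
    hop = stride or duration  # None -> 1.0 (back-to-back); 0.5 -> 50% overlap
    print([round(i * hop, 1) for i in range(3)])  # [0.0, 1.0, 2.0] vs [0.0, 0.5, 1.0]
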
@@ -72,6 +77,7 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        print("num_workers-main", self.num_workers)
         if valid_minutes > 0.0:
             self.valid_minutes = valid_minutes
         else:

@@ -90,11 +96,15 @@ class TaskDataset(pl.LightningDataModule):
                 self.name, train_clean, train_noisy, self.matching_function
             )
             train_data = fp.prepare_matching_dict()
-            self.train_data, self.val_data = self.train_valid_split(
+            train_data, self.val_data = self.train_valid_split(
                 train_data, valid_minutes=self.valid_minutes, random_state=42
             )

+            self.train_data = self.prepare_traindata(train_data)
             self._validation = self.prepare_mapstype(self.val_data)
+            print(
+                "train_data_size", sum([len(item) for item in self.train_data])
+            )

             test_clean = os.path.join(self.root_dir, self.files.test_clean)
             test_noisy = os.path.join(self.root_dir, self.files.test_noisy)

@@ -126,6 +136,32 @@ class TaskDataset(pl.LightningDataModule):
         valid_data = [item for i, item in enumerate(data) if i in valid_indices]
         return train_data, valid_data

+    def prepare_traindata(self, data):
+        train_data = []
+        for item in data:
+            samples_metadata = []
+            clean, noisy, total_dur = item.values()
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start = index * self.stride
+                samples_metadata.append(
+                    ({"clean": clean, "noisy": noisy}, start)
+                )
+            train_data.append(samples_metadata)
+        return train_data
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+
+        if file_duration < duration:
+            num_segments = 1
+        else:
+            num_segments = math.ceil((file_duration - duration) / stride) + 1
+
+        return num_segments
+
     def prepare_mapstype(self, data):

         metadata = []

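As a worked example of the new segment arithmetic: a 2.5 s file windowed with `duration=1.0` and `stride=1.0` yields `ceil((2.5 - 1.0) / 1.0) + 1 = 3` segments starting at 0.0, 1.0 and 2.0 s; the last window runs past the end of the file, so `prepare_segment` presumably pads it. The file duration below is hypothetical:

import math

file_duration, duration, stride = 2.5, 1.0, 1.0
num_segments = math.ceil((file_duration - duration) / stride) + 1
starts = [i * stride for i in range(num_segments)]
print(num_segments, starts)  # 3 [0.0, 1.0, 2.0]
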
@@ -142,11 +178,32 @@ class TaskDataset(pl.LightningDataModule):
             )
         return metadata

+    def train_collatefn(self, batch):
+        output = {"noisy": [], "clean": []}
+        for item in batch:
+            output["noisy"].append(item["noisy"])
+            output["clean"].append(item["clean"])
+
+        output["clean"] = torch.stack(output["clean"], dim=0)
+        output["noisy"] = torch.stack(output["noisy"], dim=0)
+        return output
+
+    def worker_init_fn(self, _):
+        worker_info = torch.utils.data.get_worker_info()
+        dataset = worker_info.dataset
+        worker_id = worker_info.id
+        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
+        dataset.data = dataset.dataset.train_data[
+            worker_id * split_size : (worker_id + 1) * split_size
+        ]
+
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=self.batch_size,
+            batch_size=None,
             num_workers=self.num_workers,
+            collate_fn=self.train_collatefn,
+            worker_init_fn=self.worker_init_fn,
         )

     def val_dataloader(self):

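Two things happen in this hunk. `batch_size=None` turns off the DataLoader's automatic batching, which fits an IterableDataset whose iterator already yields whole batches (the `zip` of `batch_size` streams in `train__iter__` below); `train_collatefn` then stacks each batch into tensors. Separately, `worker_init_fn` hands every worker a disjoint slice of the training list so parallel workers do not replay the same samples. A self-contained sketch of the sharding idea, with illustrative names that are not from this repo:

import torch
from torch.utils.data import DataLoader, IterableDataset

class ShardedStream(IterableDataset):
    def __init__(self, items):
        self.items = items  # full list; workers narrow `data` in worker_init_fn
        self.data = items

    def __iter__(self):
        return iter(self.data)

def worker_init_fn(_):
    info = torch.utils.data.get_worker_info()
    shard = len(info.dataset.items) // info.num_workers
    info.dataset.data = info.dataset.items[info.id * shard : (info.id + 1) * shard]

loader = DataLoader(ShardedStream(list(range(8))), batch_size=None,
                    num_workers=2, worker_init_fn=worker_init_fn)
print(sorted(x for x in loader))  # [0, 1, ..., 7] with no duplicates

Note the floor division silently drops the tail whenever the list length is not a multiple of `num_workers`; the diff's `worker_init_fn` has the same property.
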
@@ -227,24 +284,25 @@ class EnhancerDataset(TaskDataset):

         super().setup(stage=stage)

+    def random_sample(self, train_data):
+        return random.sample(train_data, len(train_data))
+
     def train__iter__(self):

         rng = create_unique_rng(self.model.current_epoch)
+        train_data = rng.sample(self.train_data, len(self.train_data))

-        while True:
-
-            file_dict, *_ = rng.choices(
-                self.train_data,
-                k=1,
-                weights=[file["duration"] for file in self.train_data],
-            )
-            file_duration = file_dict["duration"]
-            num_segments = self.get_num_segments(
-                file_duration, self.duration, self.stride
-            )
-            for index in range(0, num_segments):
-                start_time = index * self.stride
-                yield self.prepare_segment(file_dict, start_time)
+        return zip(
+            *[
+                self.get_stream(self.random_sample(train_data))
+                for i in range(self.batch_size)
+            ]
+        )
+
+    def get_stream(self, data):
+        return chain.from_iterable(map(self.process_data, cycle(data)))
+
+    def process_data(self, data):
+        for item in data:
+            yield self.prepare_segment(*item)

     @staticmethod
     def get_num_segments(file_duration, duration, stride):

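The rewritten iterator is the standard itertools streaming pattern: `cycle(data)` repeats the shuffled file list indefinitely, `map(self.process_data, ...)` turns each file's segment list into a generator, and `chain.from_iterable` flattens those generators into one endless sample stream; zipping `batch_size` independent streams then yields one tuple of samples per step. A toy sketch of the pattern (names and data are illustrative):

import random
from itertools import chain, cycle, islice

def make_stream(files):
    shuffled = random.sample(files, len(files))  # independent order per stream
    return chain.from_iterable(cycle(shuffled))  # endless flat stream of segments

files = [["a0", "a1"], ["b0"]]                   # toy per-file segment lists
batches = zip(*[make_stream(files) for _ in range(2)])  # batch_size = 2
print(list(islice(batches, 3)))                  # three "batches" of 2 segments each

Because every stream cycles forever, the epoch boundary comes entirely from `train__len__` (next hunk), not from stream exhaustion.
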
@@ -288,20 +346,14 @@ class EnhancerDataset(TaskDataset):
                 ),
             ),
         )
-        return {"clean": clean_segment, "noisy": noisy_segment}
+        return {
+            "clean": clean_segment,
+            "noisy": noisy_segment,
+            "name": file_dict["clean"].split("/")[-1] + "->" + str(start_time),
+        }

     def train__len__(self):
-
-        return math.ceil(
-            sum(
-                [
-                    self.get_num_segments(
-                        file["duration"], self.duration, self.stride
-                    )
-                    for file in self.train_data
-                ]
-            )
-        )
+        return sum([len(item) for item in self.train_data]) // (self.batch_size)

     def val__len__(self):
         return len(self._validation)

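`train__len__` now counts whole batches rather than individual segments, since each step of the zipped iterator already delivers `batch_size` samples; the floor division also drops any partial final batch. For example, with hypothetical sizes of 25 files at 40 segments each and `batch_size=32`:

train_data = [[None] * 40 for _ in range(25)]  # 25 files x 40 dummy segments
batch_size = 32
steps_per_epoch = sum(len(item) for item in train_data) // batch_size
print(steps_per_epoch)  # 1000 // 32 = 31
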
@@ -3,7 +3,7 @@ import logging
 import numpy as np
 import torch
 import torch.nn as nn
-from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from pesq import pesq
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

@@ -122,7 +122,6 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
-        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)

     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):

@@ -130,7 +129,12 @@ class Pesq:
         for pred, target_ in zip(prediction, target):
             try:
                 pesq_values.append(
-                    self.pesq(pred.squeeze(), target_.squeeze()).item()
+                    pesq(
+                        self.sr,
+                        target_.squeeze().detach().numpy(),
+                        pred.squeeze().detach().numpy(),
+                        self.mode,
+                    )
                 )
             except Exception as e:
                 logging.warning(f"{e} error occurred while calculating PESQ")

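The metric now calls the reference `pesq` package directly instead of the torchmetrics wrapper; its signature is `pesq(fs, ref, deg, mode)` with NumPy arrays, where `mode` is `"wb"` (16 kHz only) or `"nb"`. A minimal sketch; the random signals are placeholders for real speech, and PESQ can raise on non-speech input, which is exactly what the `try/except` above guards against:

import numpy as np
from pesq import pesq

sr = 16000
ref = np.random.randn(sr).astype(np.float32)                     # 1 s "reference"
deg = 0.5 * ref + 0.1 * np.random.randn(sr).astype(np.float32)   # degraded copy
score = pesq(sr, ref, deg, "wb")  # wide-band MOS-LQO, roughly 1.0-4.6
print(score)
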
@@ -5,7 +5,7 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
-pesq==0.0.4
+git+https://github.com/ludlows/python-pesq#egg=pesq
 protobuf>=3.19.6
 pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3

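Moving from the PyPI pin `pesq==0.0.4` to a direct Git requirement installs whatever is on the repository's default branch, which typically picks up unreleased fixes but gives up reproducibility; pinning a ref, e.g. `git+https://github.com/ludlows/python-pesq@<commit>#egg=pesq`, would restore a deterministic build.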