merge dev

shahules786 2022-10-22 11:21:22 +05:30
commit 4b34cf6980
5 changed files with 90 additions and 33 deletions

.gitignore (vendored)
View File

@@ -1,5 +1,6 @@
#local
*.ckpt
*_local.yaml
cli/train_config/dataset/Vctk_local.yaml
.DS_Store
outputs/

View File

@@ -9,7 +9,7 @@ benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: -1
devices: 1
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True

View File

@@ -1,9 +1,12 @@
import math
import multiprocessing
import os
import random
from itertools import chain, cycle
from typing import Optional
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
@@ -55,6 +58,7 @@ class TaskDataset(pl.LightningDataModule):
files: Files,
valid_minutes: float = 0.20,
duration: float = 1.0,
stride=None,
sampling_rate: int = 48000,
matching_function=None,
batch_size=32,
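Aside (not part of the diff): a rough sketch of how the new stride argument might be used, assuming EnhancerDataset forwards these keyword arguments to TaskDataset unchanged. The import path, directory names, and values below are placeholders rather than names taken from the repository.

from enhancer.data import EnhancerDataset, Files  # import path is an assumption

files = Files(  # assumed fields, based on the attributes referenced in this diff
    train_clean="clean_trainset_wav",
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dataset = EnhancerDataset(
    name="vctk",               # placeholder dataset name
    root_dir="/path/to/VCTK",  # placeholder path
    files=files,
    duration=4.0,              # seconds per training segment
    stride=2.0,                # new argument: hop between segment start times; defaults to duration
    sampling_rate=48000,
    batch_size=32,
)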
@@ -65,6 +69,7 @@ class TaskDataset(pl.LightningDataModule):
self.name = name
self.files, self.root_dir = check_files(root_dir, files)
self.duration = duration
self.stride = stride or duration
self.sampling_rate = sampling_rate
self.batch_size = batch_size
self.matching_function = matching_function
@@ -72,6 +77,7 @@ class TaskDataset(pl.LightningDataModule):
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
print("num_workers-main", self.num_workers)
if valid_minutes > 0.0:
self.valid_minutes = valid_minutes
else:
@@ -90,11 +96,15 @@ class TaskDataset(pl.LightningDataModule):
self.name, train_clean, train_noisy, self.matching_function
)
train_data = fp.prepare_matching_dict()
self.train_data, self.val_data = self.train_valid_split(
train_data, self.val_data = self.train_valid_split(
train_data, valid_minutes=self.valid_minutes, random_state=42
)
self.train_data = self.prepare_traindata(train_data)
self._validation = self.prepare_mapstype(self.val_data)
print(
"train_data_size", sum([len(item) for item in self.train_data])
)
test_clean = os.path.join(self.root_dir, self.files.test_clean)
test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
@@ -126,6 +136,32 @@ class TaskDataset(pl.LightningDataModule):
valid_data = [item for i, item in enumerate(data) if i in valid_indices]
return train_data, valid_data
def prepare_traindata(self, data):
train_data = []
for item in data:
samples_metadata = []
clean, noisy, total_dur = item.values()
num_segments = self.get_num_segments(
total_dur, self.duration, self.stride
)
for index in range(num_segments):
start = index * self.stride
samples_metadata.append(
({"clean": clean, "noisy": noisy}, start)
)
train_data.append(samples_metadata)
return train_data
@staticmethod
def get_num_segments(file_duration, duration, stride):
if file_duration < duration:
num_segments = 1
else:
num_segments = math.ceil((file_duration - duration) / stride) + 1
return num_segments
def prepare_mapstype(self, data):
metadata = []
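A quick sanity check of the get_num_segments arithmetic introduced in this hunk (not part of the diff): with the default stride equal to duration, a 10.4 s file split into 1.0 s segments yields ceil((10.4 - 1.0) / 1.0) + 1 = 11 segments, and halving the stride roughly doubles that count.

import math

def get_num_segments(file_duration, duration, stride):
    # mirrors the static method added above
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1

assert get_num_segments(10.4, duration=1.0, stride=1.0) == 11  # ceil(9.4 / 1.0) + 1
assert get_num_segments(10.4, duration=1.0, stride=0.5) == 20  # ceil(9.4 / 0.5) + 1
assert get_num_segments(0.6, duration=1.0, stride=1.0) == 1    # short files still yield one segment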
@@ -142,11 +178,32 @@ class TaskDataset(pl.LightningDataModule):
)
return metadata
def train_collatefn(self, batch):
output = {"noisy": [], "clean": []}
for item in batch:
output["noisy"].append(item["noisy"])
output["clean"].append(item["clean"])
output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0)
return output
def worker_init_fn(self, _):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
split_size = len(dataset.dataset.train_data) // worker_info.num_workers
dataset.data = dataset.dataset.train_data[
worker_id * split_size : (worker_id + 1) * split_size
]
def train_dataloader(self):
return DataLoader(
TrainDataset(self),
batch_size=self.batch_size,
batch_size=None,
num_workers=self.num_workers,
collate_fn=self.train_collatefn,
worker_init_fn=self.worker_init_fn,
)
def val_dataloader(self):
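Context for the batch_size=None change (not part of the diff): each item yielded by TrainDataset is already a full batch, because train__iter__ (shown further down) zips batch_size independent sample streams, so automatic batching is disabled and train_collatefn is applied once per yielded tuple; worker_init_fn then hands every DataLoader worker a disjoint slice of train_data so workers do not emit duplicate segments. Below is a self-contained sketch of the same pattern with toy data; all names are illustrative and not from the repository.

import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class ToyBatchedDataset(IterableDataset):
    """Yields tuples that are already one full batch of samples."""

    def __init__(self, files, batch_size=4):
        self.files = files   # full list; each worker receives a slice in worker_init_fn
        self.data = files    # the slice this worker actually iterates
        self.batch_size = batch_size

    def __iter__(self):
        for i in range(0, len(self.data), self.batch_size):
            yield tuple(self.data[i : i + self.batch_size])

def worker_init_fn(_):
    info = get_worker_info()
    dataset = info.dataset
    split = len(dataset.files) // info.num_workers
    dataset.data = dataset.files[info.id * split : (info.id + 1) * split]

def collate(batch_tuple):
    # with batch_size=None the DataLoader calls this once per yielded tuple
    return torch.tensor(batch_tuple)

if __name__ == "__main__":
    loader = DataLoader(
        ToyBatchedDataset(list(range(32))),
        batch_size=None,   # automatic batching off: dataset items are already batches
        num_workers=2,
        collate_fn=collate,
        worker_init_fn=worker_init_fn,
    )
    for batch in loader:
        print(batch.shape)  # torch.Size([4])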
@@ -227,24 +284,25 @@ class EnhancerDataset(TaskDataset):
super().setup(stage=stage)
def random_sample(self, train_data):
return random.sample(train_data, len(train_data))
def train__iter__(self):
rng = create_unique_rng(self.model.current_epoch)
while True:
file_dict, *_ = rng.choices(
self.train_data,
k=1,
weights=[file["duration"] for file in self.train_data],
train_data = rng.sample(self.train_data, len(self.train_data))
return zip(
*[
self.get_stream(self.random_sample(train_data))
for i in range(self.batch_size)
]
)
file_duration = file_dict["duration"]
num_segments = self.get_num_segments(
file_duration, self.duration, self.stride
)
for index in range(0, num_segments):
start_time = index * self.stride
yield self.prepare_segment(file_dict, start_time)
def get_stream(self, data):
return chain.from_iterable(map(self.process_data, cycle(data)))
def process_data(self, data):
for item in data:
yield self.prepare_segment(*item)
@staticmethod
def get_num_segments(file_duration, duration, stride):
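The zip-of-streams construction in train__iter__ above can be hard to parse at a glance. Here is a stripped-down illustration of the same pattern (not part of the diff, all data below is made up): each of batch_size generators cycles forever over its own shuffled copy of the per-file segment lists, and zip glues them into batch-sized tuples.

import random
from itertools import chain, cycle

def process(file_segments):
    # in the real dataset each file expands into several (metadata, start_time) segments
    for segment in file_segments:
        yield segment

def stream(data, rng):
    # each stream walks its own shuffled copy of the files, endlessly
    shuffled = rng.sample(data, len(data))
    return chain.from_iterable(map(process, cycle(shuffled)))

data = [["a0", "a1"], ["b0"], ["c0", "c1", "c2"]]  # per-file segment lists (placeholders)
rng = random.Random(0)
batch_size = 2
batch_iter = zip(*[stream(data, rng) for _ in range(batch_size)])
print(next(batch_iter))  # one tuple per step, one element drawn from each independent stream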
@@ -288,20 +346,14 @@ class EnhancerDataset(TaskDataset):
),
),
)
return {"clean": clean_segment, "noisy": noisy_segment}
return {
"clean": clean_segment,
"noisy": noisy_segment,
"name": file_dict["clean"].split("/")[-1] + "->" + str(start_time),
}
def train__len__(self):
return math.ceil(
sum(
[
self.get_num_segments(
file["duration"], self.duration, self.stride
)
for file in self.train_data
]
)
)
return sum([len(item) for item in self.train_data]) // (self.batch_size)
def val__len__(self):
return len(self._validation)

View File

@@ -3,7 +3,7 @@ import logging
import numpy as np
import torch
import torch.nn as nn
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from pesq import pesq
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
@@ -122,7 +121,6 @@ class Pesq:
self.sr = sr
self.name = "pesq"
self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
@@ -130,7 +129,12 @@ class Pesq:
for pred, target_ in zip(prediction, target):
try:
pesq_values.append(
self.pesq(pred.squeeze(), target_.squeeze()).item()
pesq(
self.sr,
target_.squeeze().detach().numpy(),
pred.squeeze().detach().numpy(),
self.mode,
)
)
except Exception as e:
logging.warning(f"{e} error occurred while calculating PESQ")
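For reference (not part of the diff), the change above swaps the torchmetrics wrapper for a direct call into the ludlows python-pesq package listed in requirements. A minimal sketch of that call with placeholder signals follows; note that PESQ itself only accepts 8 kHz or 16 kHz input, with mode "nb" or "wb".

import numpy as np
from pesq import pesq

sr = 16000                              # PESQ supports only 8000 or 16000 Hz
ref = np.random.randn(sr)               # 1 s of reference ("clean") audio, placeholder data
deg = ref + 0.01 * np.random.randn(sr)  # degraded signal, placeholder data
score = pesq(sr, ref, deg, "wb")        # positional order: fs, reference, degraded, mode
print(score)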

View File

@@ -5,7 +5,7 @@ joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
pesq==0.0.4
git+https://github.com/ludlows/python-pesq#egg=pesq
protobuf>=3.19.6
pystoi==0.3.3
pytest-lazy-fixture>=0.6.3