merge dev

shahules786 2022-10-22 11:21:22 +05:30
commit 4b34cf6980
5 changed files with 90 additions and 33 deletions

.gitignore (vendored)
View File

@@ -1,5 +1,6 @@
 #local
 *.ckpt
+*_local.yaml
 cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/

View File

@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: -1
+devices: 1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True

View File

@@ -1,9 +1,12 @@
 import math
 import multiprocessing
 import os
+import random
+from itertools import chain, cycle
 from typing import Optional
 import pytorch_lightning as pl
+import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, IterableDataset
@@ -55,6 +58,7 @@ class TaskDataset(pl.LightningDataModule):
         files: Files,
         valid_minutes: float = 0.20,
         duration: float = 1.0,
+        stride=None,
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,
@@ -65,6 +69,7 @@ class TaskDataset(pl.LightningDataModule):
         self.name = name
         self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
+        self.stride = stride or duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
@@ -72,6 +77,7 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
+        print("num_workers-main", self.num_workers)
         if valid_minutes > 0.0:
             self.valid_minutes = valid_minutes
         else:
@@ -90,11 +96,15 @@ class TaskDataset(pl.LightningDataModule):
                 self.name, train_clean, train_noisy, self.matching_function
             )
             train_data = fp.prepare_matching_dict()
-            self.train_data, self.val_data = self.train_valid_split(
+            train_data, self.val_data = self.train_valid_split(
                 train_data, valid_minutes=self.valid_minutes, random_state=42
             )
+            self.train_data = self.prepare_traindata(train_data)
             self._validation = self.prepare_mapstype(self.val_data)
+            print(
+                "train_data_size", sum([len(item) for item in self.train_data])
+            )

             test_clean = os.path.join(self.root_dir, self.files.test_clean)
             test_noisy = os.path.join(self.root_dir, self.files.test_noisy)
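For orientation, a small sketch (not part of the commit) of the shape `self.train_data` takes after this change: `train_valid_split` still returns per-file matching dicts, and the new `prepare_traindata` helper (added further down) expands each file into a list of `(metadata, start_time)` pairs. The file names, durations and key order below are assumptions for illustration only.

# Hypothetical output of fp.prepare_matching_dict() after the split:
train_data = [
    {"clean": "clean/f1.wav", "noisy": "noisy/f1.wav", "duration": 2.5},
    {"clean": "clean/f2.wav", "noisy": "noisy/f2.wav", "duration": 0.8},
]
# With duration=1.0 and stride=1.0, prepare_traindata would turn this into
# one list of (paths, start_time) tuples per file:
# [
#     [({"clean": "clean/f1.wav", "noisy": "noisy/f1.wav"}, 0.0),
#      ({"clean": "clean/f1.wav", "noisy": "noisy/f1.wav"}, 1.0),
#      ({"clean": "clean/f1.wav", "noisy": "noisy/f1.wav"}, 2.0)],  # 3 segments
#     [({"clean": "clean/f2.wav", "noisy": "noisy/f2.wav"}, 0.0)],  # shorter than 1 s -> 1 segment
# ]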
@@ -126,6 +136,32 @@ class TaskDataset(pl.LightningDataModule):
         valid_data = [item for i, item in enumerate(data) if i in valid_indices]
         return train_data, valid_data

+    def prepare_traindata(self, data):
+        train_data = []
+        for item in data:
+            samples_metadata = []
+            clean, noisy, total_dur = item.values()
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start = index * self.stride
+                samples_metadata.append(
+                    ({"clean": clean, "noisy": noisy}, start)
+                )
+            train_data.append(samples_metadata)
+        return train_data
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+        if file_duration < duration:
+            num_segments = 1
+        else:
+            num_segments = math.ceil((file_duration - duration) / stride) + 1
+        return num_segments
+
     def prepare_mapstype(self, data):
         metadata = []
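A quick worked example (outside the diff) of the segment arithmetic in `get_num_segments`: with `duration=1.0` and `stride=1.0`, a 4.5 s file gives `ceil((4.5 - 1.0) / 1.0) + 1 = 5` segments starting at 0, 1, 2, 3 and 4 s, while anything shorter than `duration` still yields one segment.

import math

# Standalone copy of the static method above, for illustration only.
def get_num_segments(file_duration, duration, stride):
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1

assert get_num_segments(4.5, duration=1.0, stride=1.0) == 5
assert get_num_segments(0.8, duration=1.0, stride=1.0) == 1
assert get_num_segments(4.5, duration=2.0, stride=0.5) == 6  # overlapping windows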
@@ -142,11 +178,32 @@ class TaskDataset(pl.LightningDataModule):
             )
         return metadata

+    def train_collatefn(self, batch):
+        output = {"noisy": [], "clean": []}
+        for item in batch:
+            output["noisy"].append(item["noisy"])
+            output["clean"].append(item["clean"])
+        output["clean"] = torch.stack(output["clean"], dim=0)
+        output["noisy"] = torch.stack(output["noisy"], dim=0)
+        return output
+
+    def worker_init_fn(self, _):
+        worker_info = torch.utils.data.get_worker_info()
+        dataset = worker_info.dataset
+        worker_id = worker_info.id
+        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
+        dataset.data = dataset.dataset.train_data[
+            worker_id * split_size : (worker_id + 1) * split_size
+        ]
+
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=self.batch_size,
+            batch_size=None,
             num_workers=self.num_workers,
+            collate_fn=self.train_collatefn,
+            worker_init_fn=self.worker_init_fn,
         )

     def val_dataloader(self):
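Two things worth noting about the loader change, stated here as a reading of the diff rather than as documentation: `batch_size=None` disables the DataLoader's automatic batching because `train__iter__` (further down) already yields tuples of `batch_size` segments, which `train_collatefn` merely stacks into tensors; and `worker_init_fn` hands each worker a contiguous slice of `self.train_data` so workers do not duplicate files. A toy sketch of that slicing, with made-up numbers:

# Illustration only: how the per-worker slicing above partitions the file list.
train_data = [f"file_{i}" for i in range(10)]   # stand-ins for per-file segment lists
num_workers = 4
split_size = len(train_data) // num_workers     # 2
shards = [
    train_data[worker_id * split_size : (worker_id + 1) * split_size]
    for worker_id in range(num_workers)
]
print(shards)
# [['file_0', 'file_1'], ['file_2', 'file_3'], ['file_4', 'file_5'], ['file_6', 'file_7']]
# With 10 files and 4 workers, the last two files are left unassigned by this scheme.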
@@ -227,24 +284,25 @@ class EnhancerDataset(TaskDataset):
         super().setup(stage=stage)

+    def random_sample(self, train_data):
+        return random.sample(train_data, len(train_data))
+
     def train__iter__(self):
         rng = create_unique_rng(self.model.current_epoch)
-
-        while True:
-
-            file_dict, *_ = rng.choices(
-                self.train_data,
-                k=1,
-                weights=[file["duration"] for file in self.train_data],
-            )
-            file_duration = file_dict["duration"]
-            num_segments = self.get_num_segments(
-                file_duration, self.duration, self.stride
-            )
-            for index in range(0, num_segments):
-                start_time = index * self.stride
-                yield self.prepare_segment(file_dict, start_time)
+        train_data = rng.sample(self.train_data, len(self.train_data))
+        return zip(
+            *[
+                self.get_stream(self.random_sample(train_data))
+                for i in range(self.batch_size)
+            ]
+        )
+
+    def get_stream(self, data):
+        return chain.from_iterable(map(self.process_data, cycle(data)))
+
+    def process_data(self, data):
+        for item in data:
+            yield self.prepare_segment(*item)

     @staticmethod
     def get_num_segments(file_duration, duration, stride):
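The rewritten `train__iter__` replaces duration-weighted random file picks with `batch_size` independent, endlessly cycling streams that are zipped together, so each step yields a tuple of `batch_size` segments. A minimal, self-contained sketch of the same `chain`/`cycle`/`zip` pattern (toy data; without shuffling both streams stay in lockstep, whereas in the real code `random_sample` reshuffles each stream independently):

from itertools import chain, cycle, islice

def process_data(file_segments):
    # Stand-in for prepare_segment: yield the precomputed segment entries.
    for item in file_segments:
        yield item

def get_stream(data):
    # Endless stream of segments, file after file, repeating forever.
    return chain.from_iterable(map(process_data, cycle(data)))

files = [["a0", "a1", "a2"], ["b0"]]            # two files, pre-split into segments
batch_size = 2
batches = zip(*[get_stream(files) for _ in range(batch_size)])
print(list(islice(batches, 3)))
# [('a0', 'a0'), ('a1', 'a1'), ('a2', 'a2')]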
@@ -288,20 +346,14 @@ class EnhancerDataset(TaskDataset):
                 ),
             ),
         )

-        return {"clean": clean_segment, "noisy": noisy_segment}
+        return {
+            "clean": clean_segment,
+            "noisy": noisy_segment,
+            "name": file_dict["clean"].split("/")[-1] + "->" + str(start_time),
+        }

     def train__len__(self):
-        return math.ceil(
-            sum(
-                [
-                    self.get_num_segments(
-                        file["duration"], self.duration, self.stride
-                    )
-                    for file in self.train_data
-                ]
-            )
-        )
+        return sum([len(item) for item in self.train_data]) // (self.batch_size)

     def val__len__(self):
         return len(self._validation)
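Since every item from `train__iter__` is now already a full batch, the epoch length becomes the total number of segments divided by the batch size rather than a per-file ceiling sum. A tiny arithmetic sketch with made-up counts:

# Hypothetical numbers of segments per file (len(item) in the new train_data layout).
segments_per_file = [3, 1, 5]
batch_size = 2
steps_per_epoch = sum(segments_per_file) // batch_size  # 9 // 2 == 4 steps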

View File

@@ -3,7 +3,7 @@ import logging
 import numpy as np
 import torch
 import torch.nn as nn
-from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
+from pesq import pesq
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
@@ -122,7 +122,6 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
-        self.pesq = PerceptualEvaluationSpeechQuality(fs=sr, mode=mode)

     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
@@ -130,7 +129,12 @@ class Pesq:
         for pred, target_ in zip(prediction, target):
             try:
                 pesq_values.append(
-                    self.pesq(pred.squeeze(), target_.squeeze()).item()
+                    pesq(
+                        self.sr,
+                        target_.squeeze().detach().numpy(),
+                        pred.squeeze().detach().numpy(),
+                        self.mode,
+                    )
                 )
             except Exception as e:
                 logging.warning(f"{e} error occured while calculating PESQ")
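For reference, the `pesq` function from the ludlows/python-pesq package (now pulled in via requirements, see below) takes the sampling rate, the reference signal, the degraded signal and the mode, in that order, which matches the new call. A minimal sketch using synthetic arrays instead of real speech (PESQ expects speech, so synthetic input may be rejected; the guard mirrors the metric's own try/except):

import logging

import numpy as np
from pesq import pesq

sr = 16000                                               # pesq() accepts 8000 or 16000 Hz
rng = np.random.default_rng(0)
reference = rng.standard_normal(sr).astype(np.float32)  # stand-in for clean speech
degraded = reference + 0.1 * rng.standard_normal(sr).astype(np.float32)

try:
    # Argument order: sampling rate, reference (clean), degraded (prediction), mode.
    print(pesq(sr, reference, degraded, "wb"))
except Exception as e:
    logging.warning(f"{e} while computing PESQ on synthetic audio")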

View File

@@ -5,7 +5,7 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
-pesq==0.0.4
+git+https://github.com/ludlows/python-pesq#egg=pesq
 protobuf>=3.19.6
 pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3