merge dev

This commit is contained in:
shahules786 2022-10-27 15:23:17 +05:30
commit dbfa580618
12 changed files with 128 additions and 99 deletions

View File

@ -4,10 +4,16 @@ from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")
@ -25,8 +31,13 @@ def main(config: DictConfig):
)
parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset)
dataset = instantiate(config.dataset, augmentations=None)
model = instantiate(
config.model,
dataset=dataset,
@ -45,6 +56,8 @@ def main(config: DictConfig):
every_n_epochs=1,
)
callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False):
early_stopping = EarlyStopping(
monitor=f"valid_{parameters.get('EarlyStopping_metric','loss')}",
@ -56,11 +69,11 @@ def main(config: DictConfig):
)
callbacks.append(early_stopping)
def configure_optimizer(self):
def configure_optimizers(self):
optimizer = instantiate(
config.optimizer,
lr=parameters.get("lr"),
parameters=self.parameters(),
params=self.parameters(),
)
scheduler = ReduceLROnPlateau(
optimizer=optimizer,
@ -70,9 +83,13 @@ def main(config: DictConfig):
min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3),
)
return {"optimizer": optimizer, "lr_scheduler": scheduler}
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_parameters = MethodType(configure_optimizer, model)
model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model)

View File

@ -2,10 +2,10 @@ _target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 2
sampling_rate: 16000
batch_size: 128
valid_minutes : 10
batch_size: 32
valid_minutes : 15
files:
train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav

View File

@ -1,8 +1,7 @@
loss : mse
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
ReduceLr_patience : 10
ReduceLr_factor : 0.5
min_lr : 0.00
early_stop : True
ReduceLr_patience : 5
ReduceLr_factor : 0.2
min_lr : 0.000001
EarlyStopping_factor : 10

View File

@ -1,2 +1,2 @@
experiment_name : shahules/enhancer
run_name : baseline
run_name : Demucs + Vtck with stride + augmentations

View File

@ -1,11 +1,11 @@
_target_: enhancer.models.demucs.Demucs
num_channels: 1
resample: 2
resample: 4
sampling_rate : 16000
encoder_decoder:
depth: 5
initial_output_channels: 32
depth: 4
initial_output_channels: 64
kernel_size: 8
stride: 4
growth_factor: 2

View File

@ -1,5 +1,5 @@
_target_: enhancer.models.waveunet.WaveUnet
num_channels : 1
depth : 12
depth : 9
initial_output_channels: 24
sampling_rate : 16000

View File

@ -9,7 +9,7 @@ benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 1
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
@ -22,9 +22,10 @@ limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 100
log_every_n_steps: 50
max_epochs: 200
max_time: 00:47:00:00
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
@ -37,7 +38,7 @@ precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: null
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1

View File

@ -1,14 +1,15 @@
import math
import multiprocessing
import os
import random
from itertools import chain, cycle
from pathlib import Path
from typing import Optional
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data import DataLoader, Dataset
from torch_audiomentations import Compose
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
@ -16,13 +17,15 @@ from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng
LARGE_NUM = 2147483647
class TrainDataset(Dataset):
    """Map-style view over the training split of a data module.

    The wrapped object is expected to expose ``train__getitem__(idx)`` and
    ``train__len__()``; this class only forwards to those hooks so that a
    regular ``DataLoader`` can batch (and sample) the training examples.

    Args:
        dataset: the owning data module that implements the ``train__*`` hooks.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        """Return the training example at flat index ``idx``."""
        return self.dataset.train__getitem__(idx)

    def __len__(self):
        """Return the number of training examples reported by the data module."""
        return self.dataset.train__len__()
@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
):
super().__init__()
@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
else:
raise ValueError("valid_minutes must be greater than 0")
self.augmentations = augmentations
def setup(self, stage: Optional[str] = None):
"""
prepare train/validation/test data splits
@ -115,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
):
valid_minutes *= 60
valid_min_now = 0.0
valid_sec_now = 0.0
valid_indices = []
random_indices = list(range(0, len(data)))
rng = create_unique_rng(random_state)
rng.shuffle(random_indices)
i = 0
while valid_min_now <= valid_minutes:
valid_indices.append(random_indices[i])
valid_min_now += data[random_indices[i]]["duration"]
i += 1
all_speakers = np.unique(
[
(Path(file["clean"]).name.split("_")[0], file["duration"])
for file in data
]
)
possible_indices = list(range(0, len(all_speakers)))
rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
file_indices = [
i
for i, file in enumerate(data)
if speaker_name == Path(file["clean"]).name.split("_")[0]
]
for i in file_indices:
valid_indices.append(i)
valid_sec_now += data[i]["duration"]
train_data = [
item for i, item in enumerate(data) if i not in valid_indices
@ -135,16 +154,11 @@ class TaskDataset(pl.LightningDataModule):
def prepare_traindata(self, data):
    """Build one training record per audio file.

    Args:
        data: iterable of file records (mappings).

    Returns:
        A list of ``({"clean": ..., "noisy": ...}, num_segments)`` tuples,
        where ``num_segments`` is what ``self.get_num_segments`` reports for
        the file given ``self.duration`` and ``self.stride``.
    """
    train_data = []
    for item in data:
        # NOTE(review): positional unpacking assumes every record stores its
        # values in the order clean, noisy, duration — confirm against the
        # code that produces these records.
        clean, noisy, total_dur = item.values()
        num_segments = self.get_num_segments(
            total_dur, self.duration, self.stride
        )
        train_data.append(({"clean": clean, "noisy": noisy}, num_segments))
    return train_data
@ -166,7 +180,9 @@ class TaskDataset(pl.LightningDataModule):
if total_dur < self.duration:
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
else:
num_segments = round(total_dur / self.duration)
num_segments = self.get_num_segments(
total_dur, self.duration, self.duration
)
for index in range(num_segments):
start_time = index * self.duration
metadata.append(
@ -175,31 +191,44 @@ class TaskDataset(pl.LightningDataModule):
return metadata
def train_collatefn(self, batch):
    """Collate training samples into a batch and optionally augment it.

    Args:
        batch: list of samples, each a mapping with ``"clean"`` and
            ``"noisy"`` tensors of matching shape.

    Returns:
        A dict with ``"clean"`` and ``"noisy"`` tensors stacked along a new
        leading batch dimension. When ``self.augmentations`` is set, the clean
        signal and the noise component (noisy - clean) are augmented with the
        same transform parameters and the noisy signal is rebuilt from them.
    """
    output = {
        "clean": torch.stack([item["clean"] for item in batch], dim=0),
        "noisy": torch.stack([item["noisy"] for item in batch], dim=0),
    }
    if self.augmentations is None:
        return output

    noise = output["noisy"] - output["clean"]
    output["clean"] = self.augmentations(
        output["clean"], sample_rate=self.sampling_rate
    )
    # Freeze so the noise receives exactly the parameters just drawn for the
    # clean signal, then always unfreeze: leaving the transform frozen would
    # make every later batch reuse this batch's parameters.
    self.augmentations.freeze_parameters()
    try:
        output["noisy"] = (
            self.augmentations(noise, sample_rate=self.sampling_rate)
            + output["clean"]
        )
    finally:
        self.augmentations.unfreeze_parameters()
    return output
def worker_init_fn(self, _):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
split_size = len(dataset.dataset.train_data) // worker_info.num_workers
dataset.data = dataset.dataset.train_data[
worker_id * split_size : (worker_id + 1) * split_size
]
# Seeded torch.Generator for the training DataLoader; a new generator is
# built and seeded on every access of this property.
@property
def generator(self):
generator = torch.Generator()
# NOTE(review): `model` is not set by this class; it appears to be attached
# externally by the owning model — until then the base seed is used as-is.
if hasattr(self, "model"):
# Offset the base seed by the current epoch so the seed differs per epoch
# while staying reproducible.
seed = self.model.current_epoch + LARGE_NUM
else:
seed = LARGE_NUM
# manual_seed() returns the generator itself, so the seeded object is returned.
return generator.manual_seed(seed)
def train_dataloader(self):
return DataLoader(
TrainDataset(self),
batch_size=None,
batch_size=self.batch_size,
num_workers=self.num_workers,
generator=self.generator,
collate_fn=self.train_collatefn,
worker_init_fn=self.worker_init_fn,
)
def val_dataloader(self):
@ -256,6 +285,7 @@ class EnhancerDataset(TaskDataset):
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
):
super().__init__(
@ -268,6 +298,7 @@ class EnhancerDataset(TaskDataset):
matching_function=matching_function,
batch_size=batch_size,
num_workers=num_workers,
augmentations=augmentations,
)
self.sampling_rate = sampling_rate
@ -280,35 +311,17 @@ class EnhancerDataset(TaskDataset):
super().setup(stage=stage)
def random_sample(self, train_data):
return random.sample(train_data, len(train_data))
def train__getitem__(self, idx):
    """Return the training segment addressed by the flat index ``idx``.

    ``self.train_data`` holds one ``(filedict, num_samples)`` entry per file;
    the flat index is resolved by walking the files and subtracting each
    file's segment count until the index falls inside one of them.

    Args:
        idx: zero-based index over all training segments.

    Returns:
        Whatever ``self.prepare_segment(filedict, start)`` returns, where
        ``start`` is ``segment_index * self.stride`` (or 0 when
        ``self.duration`` is None).

    Raises:
        IndexError: if ``idx`` is negative or not smaller than the total
            number of segments (previously this silently returned ``None``).
    """
    if idx < 0:
        raise IndexError(f"train index {idx} is out of range")

    offset = idx
    for filedict, num_samples in self.train_data:
        if offset < num_samples:
            start = 0
            if self.duration is not None:
                start = offset * self.stride
            return self.prepare_segment(filedict, start)
        # Not in this file: skip past its segments and try the next one.
        offset -= num_samples

    raise IndexError(f"train index {idx} is out of range")


@staticmethod
def get_num_segments(file_duration, duration, stride):
    """Return how many windows of ``duration`` fit in a file at ``stride``.

    A file shorter than ``duration`` still yields one segment; otherwise the
    count is ``ceil((file_duration - duration) / stride) + 1``.
    """
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1
def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx])
@ -348,7 +361,8 @@ class EnhancerDataset(TaskDataset):
}
def train__len__(self):
    """Return the total number of training segments across all files.

    Sums the per-file segment counts stored as the second element of each
    ``self.train_data`` entry; an empty training split yields 0 instead of
    failing on tuple unpacking.
    """
    return sum(num_examples for _, num_examples in self.train_data)
def val__len__(self):
return len(self._validation)

View File

@ -3,7 +3,7 @@ import logging
import numpy as np
import torch
import torch.nn as nn
from pesq import pesq
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
@ -66,7 +66,7 @@ class Si_SDR:
"Invalid reduction, valid options are sum, mean, None"
)
self.higher_better = False
self.name = "Si-SDR"
self.name = "si-sdr"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
@ -122,20 +122,16 @@ class Pesq:
self.sr = sr
self.name = "pesq"
self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(
fs=self.sr, mode=self.mode
)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
    """Return the mean PESQ score over a batch.

    Each (prediction, target) pair is squeezed and scored with ``self.pesq``.
    Pairs whose scoring raises are logged and left out of the average.

    Args:
        prediction: batch of enhanced signals, iterated along dim 0.
        target: batch of reference signals, iterated along dim 0.

    Returns:
        A scalar CPU tensor with the mean score, or NaN when no pair in the
        batch could be scored.
    """
    pesq_values = []
    for pred, target_ in zip(prediction, target):
        try:
            score = self.pesq(pred.squeeze(), target_.squeeze())
        except Exception as e:
            # Deliberately best-effort: one unscorable utterance must not
            # abort the whole evaluation step.
            logging.warning(f"{e} error occured while calculating PESQ")
        else:
            # Keep scores as CPU tensors; averaging via numpy breaks when the
            # metric returns tensors living on a GPU.
            pesq_values.append(
                torch.as_tensor(score, dtype=torch.float32).detach().cpu()
            )
    if not pesq_values:
        return torch.tensor(float("nan"))
    return torch.stack(pesq_values).mean()

View File

@ -113,6 +113,8 @@ class Model(pl.LightningModule):
if stage == "fit":
torch.cuda.empty_cache()
self.dataset.setup(stage)
self.dataset.model = self
print(
"Total train duration",
self.dataset.train_dataloader().dataset.__len__()
@ -134,7 +136,6 @@ class Model(pl.LightningModule):
/ 60,
"minutes",
)
self.dataset.model = self
def train_dataloader(self):
return self.dataset.train_dataloader()

View File

@ -70,7 +70,7 @@ class Audio:
if sampling_rate:
audio = self.__class__.resample_audio(
audio, self.sampling_rate, sampling_rate
audio, sampling_rate, self.sampling_rate
)
if self.return_tensor:
return torch.tensor(audio)

View File

@ -5,7 +5,7 @@ joblib>=1.2.0
librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
git+https://github.com/ludlows/python-pesq#egg=pesq
pesq==0.0.4
protobuf>=3.19.6
pystoi==0.3.3
pytest-lazy-fixture>=0.6.3
@ -14,5 +14,6 @@ scikit-learn>=1.1.2
scipy>=1.9.1
soundfile>=0.11.0
torch>=1.12.1
torch-audiomentations==0.11.0
torchaudio>=0.12.1
tqdm>=4.64.1