merge dev

This commit is contained in:
shahules786 2022-10-27 15:23:17 +05:30
commit dbfa580618
12 changed files with 128 additions and 99 deletions

View File

@ -4,10 +4,16 @@ from types import MethodType
import hydra import hydra
from hydra.utils import instantiate from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
# from torch_audiomentations import Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1" os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0") JOB_ID = os.environ.get("SLURM_JOBID", "0")
@ -25,8 +31,13 @@ def main(config: DictConfig):
) )
parameters = config.hyperparameters parameters = config.hyperparameters
# apply_augmentations = Compose(
# [
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
# ]
# )
dataset = instantiate(config.dataset) dataset = instantiate(config.dataset, augmentations=None)
model = instantiate( model = instantiate(
config.model, config.model,
dataset=dataset, dataset=dataset,
@ -45,6 +56,8 @@ def main(config: DictConfig):
every_n_epochs=1, every_n_epochs=1,
) )
callbacks.append(checkpoint) callbacks.append(checkpoint)
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
if parameters.get("Early_stop", False): if parameters.get("Early_stop", False):
early_stopping = EarlyStopping( early_stopping = EarlyStopping(
monitor=f"valid_{parameters.get('EarlyStopping_metric','loss')}", monitor=f"valid_{parameters.get('EarlyStopping_metric','loss')}",
@ -56,11 +69,11 @@ def main(config: DictConfig):
) )
callbacks.append(early_stopping) callbacks.append(early_stopping)
def configure_optimizer(self): def configure_optimizers(self):
optimizer = instantiate( optimizer = instantiate(
config.optimizer, config.optimizer,
lr=parameters.get("lr"), lr=parameters.get("lr"),
parameters=self.parameters(), params=self.parameters(),
) )
scheduler = ReduceLROnPlateau( scheduler = ReduceLROnPlateau(
optimizer=optimizer, optimizer=optimizer,
@ -70,9 +83,13 @@ def main(config: DictConfig):
min_lr=parameters.get("min_lr", 1e-6), min_lr=parameters.get("min_lr", 1e-6),
patience=parameters.get("ReduceLr_patience", 3), patience=parameters.get("ReduceLr_patience", 3),
) )
return {"optimizer": optimizer, "lr_scheduler": scheduler} return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
}
model.configure_parameters = MethodType(configure_optimizer, model) model.configure_optimizers = MethodType(configure_optimizers, model)
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
trainer.fit(model) trainer.fit(model)

View File

@ -2,10 +2,10 @@ _target_: enhancer.data.dataset.EnhancerDataset
name : vctk name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5 duration : 4.5
stride : 2
sampling_rate: 16000 sampling_rate: 16000
batch_size: 128 batch_size: 32
valid_minutes : 10 valid_minutes : 15
files: files:
train_clean : clean_trainset_28spk_wav train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav test_clean : clean_testset_wav

View File

@ -1,8 +1,7 @@
loss : mse loss : mae
metric : [stoi,pesq,si-sdr] metric : [stoi,pesq,si-sdr]
lr : 0.0003 lr : 0.0003
ReduceLr_patience : 10 ReduceLr_patience : 5
ReduceLr_factor : 0.5 ReduceLr_factor : 0.2
min_lr : 0.00 min_lr : 0.000001
early_stop : True
EarlyStopping_factor : 10 EarlyStopping_factor : 10

View File

@ -1,2 +1,2 @@
experiment_name : shahules/enhancer experiment_name : shahules/enhancer
run_name : baseline run_name : Demucs + Vtck with stride + augmentations

View File

@ -1,11 +1,11 @@
_target_: enhancer.models.demucs.Demucs _target_: enhancer.models.demucs.Demucs
num_channels: 1 num_channels: 1
resample: 2 resample: 4
sampling_rate : 16000 sampling_rate : 16000
encoder_decoder: encoder_decoder:
depth: 5 depth: 4
initial_output_channels: 32 initial_output_channels: 64
kernel_size: 8 kernel_size: 8
stride: 4 stride: 4
growth_factor: 2 growth_factor: 2

View File

@ -1,5 +1,5 @@
_target_: enhancer.models.waveunet.WaveUnet _target_: enhancer.models.waveunet.WaveUnet
num_channels : 1 num_channels : 1
depth : 12 depth : 9
initial_output_channels: 24 initial_output_channels: 24
sampling_rate : 16000 sampling_rate : 16000

View File

@ -9,7 +9,7 @@ benchmark: False
check_val_every_n_epoch: 1 check_val_every_n_epoch: 1
detect_anomaly: False detect_anomaly: False
deterministic: False deterministic: False
devices: 1 devices: 2
enable_checkpointing: True enable_checkpointing: True
enable_model_summary: True enable_model_summary: True
enable_progress_bar: True enable_progress_bar: True
@ -22,9 +22,10 @@ limit_predict_batches: 1.0
limit_test_batches: 1.0 limit_test_batches: 1.0
limit_train_batches: 1.0 limit_train_batches: 1.0
limit_val_batches: 1.0 limit_val_batches: 1.0
log_every_n_steps: 100 log_every_n_steps: 50
max_epochs: 200 max_epochs: 200
max_time: 00:47:00:00 max_steps: -1
max_time: null
min_epochs: 1 min_epochs: 1
min_steps: null min_steps: null
move_metrics_to_cpu: False move_metrics_to_cpu: False
@ -37,7 +38,7 @@ precision: 32
profiler: null profiler: null
reload_dataloaders_every_n_epochs: 0 reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True replace_sampler_ddp: True
strategy: null strategy: ddp
sync_batchnorm: False sync_batchnorm: False
tpu_cores: null tpu_cores: null
track_grad_norm: -1 track_grad_norm: -1

View File

@ -1,14 +1,15 @@
import math import math
import multiprocessing import multiprocessing
import os import os
import random from pathlib import Path
from itertools import chain, cycle
from typing import Optional from typing import Optional
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset from torch.utils.data import DataLoader, Dataset
from torch_audiomentations import Compose
from enhancer.data.fileprocessor import Fileprocessor from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files from enhancer.utils import check_files
@ -16,13 +17,15 @@ from enhancer.utils.config import Files
from enhancer.utils.io import Audio from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng from enhancer.utils.random import create_unique_rng
LARGE_NUM = 2147483647
class TrainDataset(IterableDataset):
class TrainDataset(Dataset):
def __init__(self, dataset): def __init__(self, dataset):
self.dataset = dataset self.dataset = dataset
def __iter__(self): def __getitem__(self, idx):
return self.dataset.train__iter__() return self.dataset.train__getitem__(idx)
def __len__(self): def __len__(self):
return self.dataset.train__len__() return self.dataset.train__len__()
@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
matching_function=None, matching_function=None,
batch_size=32, batch_size=32,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
): ):
super().__init__() super().__init__()
@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
else: else:
raise ValueError("valid_minutes must be greater than 0") raise ValueError("valid_minutes must be greater than 0")
self.augmentations = augmentations
def setup(self, stage: Optional[str] = None): def setup(self, stage: Optional[str] = None):
""" """
prepare train/validation/test data splits prepare train/validation/test data splits
@ -115,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
): ):
valid_minutes *= 60 valid_minutes *= 60
valid_min_now = 0.0 valid_sec_now = 0.0
valid_indices = [] valid_indices = []
random_indices = list(range(0, len(data))) all_speakers = np.unique(
rng = create_unique_rng(random_state) [
rng.shuffle(random_indices) (Path(file["clean"]).name.split("_")[0], file["duration"])
i = 0 for file in data
while valid_min_now <= valid_minutes: ]
valid_indices.append(random_indices[i]) )
valid_min_now += data[random_indices[i]]["duration"] possible_indices = list(range(0, len(all_speakers)))
i += 1 rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
file_indices = [
i
for i, file in enumerate(data)
if speaker_name == Path(file["clean"]).name.split("_")[0]
]
for i in file_indices:
valid_indices.append(i)
valid_sec_now += data[i]["duration"]
train_data = [ train_data = [
item for i, item in enumerate(data) if i not in valid_indices item for i, item in enumerate(data) if i not in valid_indices
@ -135,16 +154,11 @@ class TaskDataset(pl.LightningDataModule):
def prepare_traindata(self, data): def prepare_traindata(self, data):
train_data = [] train_data = []
for item in data: for item in data:
samples_metadata = []
clean, noisy, total_dur = item.values() clean, noisy, total_dur = item.values()
num_segments = self.get_num_segments( num_segments = self.get_num_segments(
total_dur, self.duration, self.stride total_dur, self.duration, self.stride
) )
for index in range(num_segments): samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
start = index * self.stride
samples_metadata.append(
({"clean": clean, "noisy": noisy}, start)
)
train_data.append(samples_metadata) train_data.append(samples_metadata)
return train_data return train_data
@ -166,7 +180,9 @@ class TaskDataset(pl.LightningDataModule):
if total_dur < self.duration: if total_dur < self.duration:
metadata.append(({"clean": clean, "noisy": noisy}, 0.0)) metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
else: else:
num_segments = round(total_dur / self.duration) num_segments = self.get_num_segments(
total_dur, self.duration, self.duration
)
for index in range(num_segments): for index in range(num_segments):
start_time = index * self.duration start_time = index * self.duration
metadata.append( metadata.append(
@ -175,31 +191,44 @@ class TaskDataset(pl.LightningDataModule):
return metadata return metadata
def train_collatefn(self, batch): def train_collatefn(self, batch):
output = {"noisy": [], "clean": []}
output = {"clean": [], "noisy": []}
for item in batch: for item in batch:
output["noisy"].append(item["noisy"])
output["clean"].append(item["clean"]) output["clean"].append(item["clean"])
output["noisy"].append(item["noisy"])
output["clean"] = torch.stack(output["clean"], dim=0) output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0) output["noisy"] = torch.stack(output["noisy"], dim=0)
if self.augmentations is not None:
noise = output["noisy"] - output["clean"]
output["clean"] = self.augmentations(
output["clean"], sample_rate=self.sampling_rate
)
self.augmentations.freeze_parameters()
output["noisy"] = (
self.augmentations(noise, sample_rate=self.sampling_rate)
+ output["clean"]
)
return output return output
def worker_init_fn(self, _): @property
worker_info = torch.utils.data.get_worker_info() def generator(self):
dataset = worker_info.dataset generator = torch.Generator()
worker_id = worker_info.id if hasattr(self, "model"):
split_size = len(dataset.dataset.train_data) // worker_info.num_workers seed = self.model.current_epoch + LARGE_NUM
dataset.data = dataset.dataset.train_data[ else:
worker_id * split_size : (worker_id + 1) * split_size seed = LARGE_NUM
] return generator.manual_seed(seed)
def train_dataloader(self): def train_dataloader(self):
return DataLoader( return DataLoader(
TrainDataset(self), TrainDataset(self),
batch_size=None, batch_size=self.batch_size,
num_workers=self.num_workers, num_workers=self.num_workers,
generator=self.generator,
collate_fn=self.train_collatefn, collate_fn=self.train_collatefn,
worker_init_fn=self.worker_init_fn,
) )
def val_dataloader(self): def val_dataloader(self):
@ -256,6 +285,7 @@ class EnhancerDataset(TaskDataset):
matching_function=None, matching_function=None,
batch_size=32, batch_size=32,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
augmentations: Optional[Compose] = None,
): ):
super().__init__( super().__init__(
@ -268,6 +298,7 @@ class EnhancerDataset(TaskDataset):
matching_function=matching_function, matching_function=matching_function,
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
augmentations=augmentations,
) )
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
@ -280,35 +311,17 @@ class EnhancerDataset(TaskDataset):
super().setup(stage=stage) super().setup(stage=stage)
def random_sample(self, train_data): def train__getitem__(self, idx):
return random.sample(train_data, len(train_data))
def train__iter__(self): for filedict, num_samples in self.train_data:
rng = create_unique_rng(self.model.current_epoch) if idx >= num_samples:
train_data = rng.sample(self.train_data, len(self.train_data)) idx -= num_samples
return zip( continue
*[ else:
self.get_stream(self.random_sample(train_data)) start = 0
for i in range(self.batch_size) if self.duration is not None:
] start = idx * self.stride
) return self.prepare_segment(filedict, start)
def get_stream(self, data):
return chain.from_iterable(map(self.process_data, cycle(data)))
def process_data(self, data):
for item in data:
yield self.prepare_segment(*item)
@staticmethod
def get_num_segments(file_duration, duration, stride):
if file_duration < duration:
num_segments = 1
else:
num_segments = math.ceil((file_duration - duration) / stride) + 1
return num_segments
def val__getitem__(self, idx): def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx]) return self.prepare_segment(*self._validation[idx])
@ -348,7 +361,8 @@ class EnhancerDataset(TaskDataset):
} }
def train__len__(self): def train__len__(self):
return sum([len(item) for item in self.train_data]) // (self.batch_size) _, num_examples = list(zip(*self.train_data))
return sum(num_examples)
def val__len__(self): def val__len__(self):
return len(self._validation) return len(self._validation)

View File

@ -3,7 +3,7 @@ import logging
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from pesq import pesq from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
@ -66,7 +66,7 @@ class Si_SDR:
"Invalid reduction, valid options are sum, mean, None" "Invalid reduction, valid options are sum, mean, None"
) )
self.higher_better = False self.higher_better = False
self.name = "Si-SDR" self.name = "si-sdr"
def __call__(self, prediction: torch.Tensor, target: torch.Tensor): def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
@ -122,20 +122,16 @@ class Pesq:
self.sr = sr self.sr = sr
self.name = "pesq" self.name = "pesq"
self.mode = mode self.mode = mode
self.pesq = PerceptualEvaluationSpeechQuality(
fs=self.sr, mode=self.mode
)
def __call__(self, prediction: torch.Tensor, target: torch.Tensor): def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
pesq_values = [] pesq_values = []
for pred, target_ in zip(prediction, target): for pred, target_ in zip(prediction, target):
try: try:
pesq_values.append( pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
pesq(
self.sr,
target_.squeeze().detach().cpu().numpy(),
pred.squeeze().detach().cpu().numpy(),
self.mode,
)
)
except Exception as e: except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ") logging.warning(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values)) return torch.tensor(np.mean(pesq_values))

View File

@ -113,6 +113,8 @@ class Model(pl.LightningModule):
if stage == "fit": if stage == "fit":
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.dataset.setup(stage) self.dataset.setup(stage)
self.dataset.model = self
print( print(
"Total train duration", "Total train duration",
self.dataset.train_dataloader().dataset.__len__() self.dataset.train_dataloader().dataset.__len__()
@ -134,7 +136,6 @@ class Model(pl.LightningModule):
/ 60, / 60,
"minutes", "minutes",
) )
self.dataset.model = self
def train_dataloader(self): def train_dataloader(self):
return self.dataset.train_dataloader() return self.dataset.train_dataloader()

View File

@ -70,7 +70,7 @@ class Audio:
if sampling_rate: if sampling_rate:
audio = self.__class__.resample_audio( audio = self.__class__.resample_audio(
audio, self.sampling_rate, sampling_rate audio, sampling_rate, self.sampling_rate
) )
if self.return_tensor: if self.return_tensor:
return torch.tensor(audio) return torch.tensor(audio)

View File

@ -5,7 +5,7 @@ joblib>=1.2.0
librosa>=0.9.2 librosa>=0.9.2
mlflow>=1.29.0 mlflow>=1.29.0
numpy>=1.23.3 numpy>=1.23.3
git+https://github.com/ludlows/python-pesq#egg=pesq pesq==0.0.4
protobuf>=3.19.6 protobuf>=3.19.6
pystoi==0.3.3 pystoi==0.3.3
pytest-lazy-fixture>=0.6.3 pytest-lazy-fixture>=0.6.3
@ -14,5 +14,6 @@ scikit-learn>=1.1.2
scipy>=1.9.1 scipy>=1.9.1
soundfile>=0.11.0 soundfile>=0.11.0
torch>=1.12.1 torch>=1.12.1
torch-audiomentations==0.11.0
torchaudio>=0.12.1 torchaudio>=0.12.1
tqdm>=4.64.1 tqdm>=4.64.1