Merge pull request #17 from shahules786/dev-datafix
foolproof iteration
This commit is contained in:
commit
a1445b0a95
|
|
@ -4,10 +4,16 @@ from types import MethodType
|
|||
import hydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
)
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
||||
# from torch_audiomentations import Compose, Shift
|
||||
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||
|
||||
|
|
@ -25,8 +31,13 @@ def main(config: DictConfig):
|
|||
)
|
||||
|
||||
parameters = config.hyperparameters
|
||||
# apply_augmentations = Compose(
|
||||
# [
|
||||
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||
# ]
|
||||
# )
|
||||
|
||||
dataset = instantiate(config.dataset)
|
||||
dataset = instantiate(config.dataset, augmentations=None)
|
||||
model = instantiate(
|
||||
config.model,
|
||||
dataset=dataset,
|
||||
|
|
@ -45,6 +56,8 @@ def main(config: DictConfig):
|
|||
every_n_epochs=1,
|
||||
)
|
||||
callbacks.append(checkpoint)
|
||||
callbacks.append(LearningRateMonitor(logging_interval="epoch"))
|
||||
|
||||
if parameters.get("Early_stop", False):
|
||||
early_stopping = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
|
|
@ -56,11 +69,11 @@ def main(config: DictConfig):
|
|||
)
|
||||
callbacks.append(early_stopping)
|
||||
|
||||
def configure_optimizer(self):
|
||||
def configure_optimizers(self):
|
||||
optimizer = instantiate(
|
||||
config.optimizer,
|
||||
lr=parameters.get("lr"),
|
||||
parameters=self.parameters(),
|
||||
params=self.parameters(),
|
||||
)
|
||||
scheduler = ReduceLROnPlateau(
|
||||
optimizer=optimizer,
|
||||
|
|
@ -70,9 +83,13 @@ def main(config: DictConfig):
|
|||
min_lr=parameters.get("min_lr", 1e-6),
|
||||
patience=parameters.get("ReduceLr_patience", 3),
|
||||
)
|
||||
return {"optimizer": optimizer, "lr_scheduler": scheduler}
|
||||
return {
|
||||
"optimizer": optimizer,
|
||||
"lr_scheduler": scheduler,
|
||||
"monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
|
||||
}
|
||||
|
||||
model.configure_parameters = MethodType(configure_optimizer, model)
|
||||
model.configure_optimizers = MethodType(configure_optimizers, model)
|
||||
|
||||
trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ _target_: enhancer.data.dataset.EnhancerDataset
|
|||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 4.5
|
||||
stride : 0.5
|
||||
stride : 2
|
||||
sampling_rate: 16000
|
||||
batch_size: 4
|
||||
valid_minutes : 1
|
||||
batch_size: 32
|
||||
valid_minutes : 15
|
||||
files:
|
||||
train_clean : clean_trainset_28spk_wav
|
||||
test_clean : clean_testset_wav
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
loss : mse
|
||||
metric : mae
|
||||
lr : 0.0001
|
||||
loss : mae
|
||||
metric : [stoi,pesq,si-sdr]
|
||||
lr : 0.0003
|
||||
ReduceLr_patience : 5
|
||||
ReduceLr_factor : 0.1
|
||||
ReduceLr_factor : 0.2
|
||||
min_lr : 0.000001
|
||||
EarlyStopping_factor : 10
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
experiment_name : shahules/enhancer
|
||||
run_name : baseline
|
||||
run_name : Demucs + Vtck with stride + augmentations
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
_target_: enhancer.models.demucs.Demucs
|
||||
num_channels: 1
|
||||
resample: 2
|
||||
resample: 4
|
||||
sampling_rate : 16000
|
||||
|
||||
encoder_decoder:
|
||||
depth: 5
|
||||
initial_output_channels: 32
|
||||
depth: 4
|
||||
initial_output_channels: 64
|
||||
kernel_size: 8
|
||||
stride: 4
|
||||
growth_factor: 2
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
_target_: enhancer.models.waveunet.WaveUnet
|
||||
num_channels : 1
|
||||
depth : 12
|
||||
depth : 9
|
||||
initial_output_channels: 24
|
||||
sampling_rate : 16000
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
_target_: pytorch_lightning.Trainer
|
||||
accelerator: auto
|
||||
accelerator: gpu
|
||||
accumulate_grad_batches: 1
|
||||
amp_backend: native
|
||||
auto_lr_find: True
|
||||
|
|
@ -9,7 +9,7 @@ benchmark: False
|
|||
check_val_every_n_epoch: 1
|
||||
detect_anomaly: False
|
||||
deterministic: False
|
||||
devices: 1
|
||||
devices: 2
|
||||
enable_checkpointing: True
|
||||
enable_model_summary: True
|
||||
enable_progress_bar: True
|
||||
|
|
@ -23,7 +23,7 @@ limit_test_batches: 1.0
|
|||
limit_train_batches: 1.0
|
||||
limit_val_batches: 1.0
|
||||
log_every_n_steps: 50
|
||||
max_epochs: 3
|
||||
max_epochs: 200
|
||||
max_steps: -1
|
||||
max_time: null
|
||||
min_epochs: 1
|
||||
|
|
@ -38,7 +38,7 @@ precision: 32
|
|||
profiler: null
|
||||
reload_dataloaders_every_n_epochs: 0
|
||||
replace_sampler_ddp: True
|
||||
strategy: null
|
||||
strategy: ddp
|
||||
sync_batchnorm: False
|
||||
tpu_cores: null
|
||||
track_grad_norm: -1
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
from itertools import chain, cycle
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch_audiomentations import Compose
|
||||
|
||||
from enhancer.data.fileprocessor import Fileprocessor
|
||||
from enhancer.utils import check_files
|
||||
|
|
@ -16,13 +17,15 @@ from enhancer.utils.config import Files
|
|||
from enhancer.utils.io import Audio
|
||||
from enhancer.utils.random import create_unique_rng
|
||||
|
||||
LARGE_NUM = 2147483647
|
||||
|
||||
class TrainDataset(IterableDataset):
|
||||
|
||||
class TrainDataset(Dataset):
|
||||
def __init__(self, dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def __iter__(self):
|
||||
return self.dataset.train__iter__()
|
||||
def __getitem__(self, idx):
|
||||
return self.dataset.train__getitem__(idx)
|
||||
|
||||
def __len__(self):
|
||||
return self.dataset.train__len__()
|
||||
|
|
@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
matching_function=None,
|
||||
batch_size=32,
|
||||
num_workers: Optional[int] = None,
|
||||
augmentations: Optional[Compose] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
|
|||
else:
|
||||
raise ValueError("valid_minutes must be greater than 0")
|
||||
|
||||
self.augmentations = augmentations
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
"""
|
||||
prepare train/validation/test data splits
|
||||
|
|
@ -115,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
|
|||
):
|
||||
|
||||
valid_minutes *= 60
|
||||
valid_min_now = 0.0
|
||||
valid_sec_now = 0.0
|
||||
valid_indices = []
|
||||
random_indices = list(range(0, len(data)))
|
||||
rng = create_unique_rng(random_state)
|
||||
rng.shuffle(random_indices)
|
||||
i = 0
|
||||
while valid_min_now <= valid_minutes:
|
||||
valid_indices.append(random_indices[i])
|
||||
valid_min_now += data[random_indices[i]]["duration"]
|
||||
i += 1
|
||||
all_speakers = np.unique(
|
||||
[
|
||||
(Path(file["clean"]).name.split("_")[0], file["duration"])
|
||||
for file in data
|
||||
]
|
||||
)
|
||||
possible_indices = list(range(0, len(all_speakers)))
|
||||
rng = create_unique_rng(len(all_speakers))
|
||||
|
||||
while valid_sec_now <= valid_minutes:
|
||||
speaker_index = rng.choice(possible_indices)
|
||||
possible_indices.remove(speaker_index)
|
||||
speaker_name = all_speakers[speaker_index]
|
||||
file_indices = [
|
||||
i
|
||||
for i, file in enumerate(data)
|
||||
if speaker_name == Path(file["clean"]).name.split("_")[0]
|
||||
]
|
||||
for i in file_indices:
|
||||
valid_indices.append(i)
|
||||
valid_sec_now += data[i]["duration"]
|
||||
|
||||
train_data = [
|
||||
item for i, item in enumerate(data) if i not in valid_indices
|
||||
|
|
@ -135,16 +154,11 @@ class TaskDataset(pl.LightningDataModule):
|
|||
def prepare_traindata(self, data):
|
||||
train_data = []
|
||||
for item in data:
|
||||
samples_metadata = []
|
||||
clean, noisy, total_dur = item.values()
|
||||
num_segments = self.get_num_segments(
|
||||
total_dur, self.duration, self.stride
|
||||
)
|
||||
for index in range(num_segments):
|
||||
start = index * self.stride
|
||||
samples_metadata.append(
|
||||
({"clean": clean, "noisy": noisy}, start)
|
||||
)
|
||||
samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
|
||||
train_data.append(samples_metadata)
|
||||
return train_data
|
||||
|
||||
|
|
@ -166,7 +180,9 @@ class TaskDataset(pl.LightningDataModule):
|
|||
if total_dur < self.duration:
|
||||
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
|
||||
else:
|
||||
num_segments = round(total_dur / self.duration)
|
||||
num_segments = self.get_num_segments(
|
||||
total_dur, self.duration, self.duration
|
||||
)
|
||||
for index in range(num_segments):
|
||||
start_time = index * self.duration
|
||||
metadata.append(
|
||||
|
|
@ -175,31 +191,44 @@ class TaskDataset(pl.LightningDataModule):
|
|||
return metadata
|
||||
|
||||
def train_collatefn(self, batch):
|
||||
output = {"noisy": [], "clean": []}
|
||||
|
||||
output = {"clean": [], "noisy": []}
|
||||
for item in batch:
|
||||
output["noisy"].append(item["noisy"])
|
||||
output["clean"].append(item["clean"])
|
||||
output["noisy"].append(item["noisy"])
|
||||
|
||||
output["clean"] = torch.stack(output["clean"], dim=0)
|
||||
output["noisy"] = torch.stack(output["noisy"], dim=0)
|
||||
|
||||
if self.augmentations is not None:
|
||||
noise = output["noisy"] - output["clean"]
|
||||
output["clean"] = self.augmentations(
|
||||
output["clean"], sample_rate=self.sampling_rate
|
||||
)
|
||||
self.augmentations.freeze_parameters()
|
||||
output["noisy"] = (
|
||||
self.augmentations(noise, sample_rate=self.sampling_rate)
|
||||
+ output["clean"]
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def worker_init_fn(self, _):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
dataset = worker_info.dataset
|
||||
worker_id = worker_info.id
|
||||
split_size = len(dataset.dataset.train_data) // worker_info.num_workers
|
||||
dataset.data = dataset.dataset.train_data[
|
||||
worker_id * split_size : (worker_id + 1) * split_size
|
||||
]
|
||||
@property
|
||||
def generator(self):
|
||||
generator = torch.Generator()
|
||||
if hasattr(self, "model"):
|
||||
seed = self.model.current_epoch + LARGE_NUM
|
||||
else:
|
||||
seed = LARGE_NUM
|
||||
return generator.manual_seed(seed)
|
||||
|
||||
def train_dataloader(self):
|
||||
return DataLoader(
|
||||
TrainDataset(self),
|
||||
batch_size=None,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
generator=self.generator,
|
||||
collate_fn=self.train_collatefn,
|
||||
worker_init_fn=self.worker_init_fn,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
|
|
@ -251,11 +280,12 @@ class EnhancerDataset(TaskDataset):
|
|||
files: Files,
|
||||
valid_minutes=5.0,
|
||||
duration=1.0,
|
||||
stride=0.5,
|
||||
stride=None,
|
||||
sampling_rate=48000,
|
||||
matching_function=None,
|
||||
batch_size=32,
|
||||
num_workers: Optional[int] = None,
|
||||
augmentations: Optional[Compose] = None,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
|
|
@ -268,6 +298,7 @@ class EnhancerDataset(TaskDataset):
|
|||
matching_function=matching_function,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
augmentations=augmentations,
|
||||
)
|
||||
|
||||
self.sampling_rate = sampling_rate
|
||||
|
|
@ -280,35 +311,17 @@ class EnhancerDataset(TaskDataset):
|
|||
|
||||
super().setup(stage=stage)
|
||||
|
||||
def random_sample(self, train_data):
|
||||
return random.sample(train_data, len(train_data))
|
||||
def train__getitem__(self, idx):
|
||||
|
||||
def train__iter__(self):
|
||||
rng = create_unique_rng(self.model.current_epoch)
|
||||
train_data = rng.sample(self.train_data, len(self.train_data))
|
||||
return zip(
|
||||
*[
|
||||
self.get_stream(self.random_sample(train_data))
|
||||
for i in range(self.batch_size)
|
||||
]
|
||||
)
|
||||
|
||||
def get_stream(self, data):
|
||||
return chain.from_iterable(map(self.process_data, cycle(data)))
|
||||
|
||||
def process_data(self, data):
|
||||
for item in data:
|
||||
yield self.prepare_segment(*item)
|
||||
|
||||
@staticmethod
|
||||
def get_num_segments(file_duration, duration, stride):
|
||||
|
||||
if file_duration < duration:
|
||||
num_segments = 1
|
||||
for filedict, num_samples in self.train_data:
|
||||
if idx >= num_samples:
|
||||
idx -= num_samples
|
||||
continue
|
||||
else:
|
||||
num_segments = math.ceil((file_duration - duration) / stride) + 1
|
||||
|
||||
return num_segments
|
||||
start = 0
|
||||
if self.duration is not None:
|
||||
start = idx * self.stride
|
||||
return self.prepare_segment(filedict, start)
|
||||
|
||||
def val__getitem__(self, idx):
|
||||
return self.prepare_segment(*self._validation[idx])
|
||||
|
|
@ -348,7 +361,8 @@ class EnhancerDataset(TaskDataset):
|
|||
}
|
||||
|
||||
def train__len__(self):
|
||||
return sum([len(item) for item in self.train_data]) // (self.batch_size)
|
||||
_, num_examples = list(zip(*self.train_data))
|
||||
return sum(num_examples)
|
||||
|
||||
def val__len__(self):
|
||||
return len(self._validation)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pesq import pesq
|
||||
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
|
||||
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
|
||||
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class Si_SDR:
|
|||
"Invalid reduction, valid options are sum, mean, None"
|
||||
)
|
||||
self.higher_better = False
|
||||
self.name = "Si-SDR"
|
||||
self.name = "si-sdr"
|
||||
|
||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
|
|
@ -122,20 +122,16 @@ class Pesq:
|
|||
self.sr = sr
|
||||
self.name = "pesq"
|
||||
self.mode = mode
|
||||
self.pesq = PerceptualEvaluationSpeechQuality(
|
||||
fs=self.sr, mode=self.mode
|
||||
)
|
||||
|
||||
def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
|
||||
|
||||
pesq_values = []
|
||||
for pred, target_ in zip(prediction, target):
|
||||
try:
|
||||
pesq_values.append(
|
||||
pesq(
|
||||
self.sr,
|
||||
target_.squeeze().detach().cpu().numpy(),
|
||||
pred.squeeze().detach().cpu().numpy(),
|
||||
self.mode,
|
||||
)
|
||||
)
|
||||
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
|
||||
except Exception as e:
|
||||
logging.warning(f"{e} error occured while calculating PESQ")
|
||||
return torch.tensor(np.mean(pesq_values))
|
||||
|
|
|
|||
|
|
@ -113,6 +113,8 @@ class Model(pl.LightningModule):
|
|||
if stage == "fit":
|
||||
torch.cuda.empty_cache()
|
||||
self.dataset.setup(stage)
|
||||
self.dataset.model = self
|
||||
|
||||
print(
|
||||
"Total train duration",
|
||||
self.dataset.train_dataloader().dataset.__len__()
|
||||
|
|
@ -134,7 +136,6 @@ class Model(pl.LightningModule):
|
|||
/ 60,
|
||||
"minutes",
|
||||
)
|
||||
self.dataset.model = self
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.dataset.train_dataloader()
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class Audio:
|
|||
|
||||
if sampling_rate:
|
||||
audio = self.__class__.resample_audio(
|
||||
audio, self.sampling_rate, sampling_rate
|
||||
audio, sampling_rate, self.sampling_rate
|
||||
)
|
||||
if self.return_tensor:
|
||||
return torch.tensor(audio)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ joblib>=1.2.0
|
|||
librosa>=0.9.2
|
||||
mlflow>=1.29.0
|
||||
numpy>=1.23.3
|
||||
git+https://github.com/ludlows/python-pesq#egg=pesq
|
||||
pesq==0.0.4
|
||||
protobuf>=3.19.6
|
||||
pystoi==0.3.3
|
||||
pytest-lazy-fixture>=0.6.3
|
||||
|
|
@ -14,5 +14,6 @@ scikit-learn>=1.1.2
|
|||
scipy>=1.9.1
|
||||
soundfile>=0.11.0
|
||||
torch>=1.12.1
|
||||
torch-audiomentations==0.11.0
|
||||
torchaudio>=0.12.1
|
||||
tqdm>=4.64.1
|
||||
|
|
|
|||
Loading…
Reference in New Issue