shahules786 2022-11-24 11:57:04 +05:30
parent 2de2c715ed
commit 249c535921
12 changed files with 0 additions and 261 deletions

View File

@@ -1,120 +0,0 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from torch_audiomentations import Compose, Shift

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
    OmegaConf.save(config, "config_log.yaml")
    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    # apply_augmentations = Compose(
    #     [
    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
    #     ]
    # )

    dataset = instantiate(config.dataset, augmentations=None)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor  # "min" or "max", depending on the validation metric
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="valid_loss",
        verbose=False,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    callbacks.append(LearningRateMonitor(logging_interval="epoch"))

    if parameters.get("Early_stop", False):
        early_stopping = EarlyStopping(
            monitor="valid_loss",  # same metric name the checkpoint monitors
            mode=direction,
            min_delta=0.0,
            patience=parameters.get("EarlyStopping_patience", 10),
            strict=True,
            verbose=False,
        )
        callbacks.append(early_stopping)

    def configure_optimizers(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            params=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
        }

    # Bind the closure above to the instantiated model as its configure_optimizers.
    model.configure_optimizers = MethodType(configure_optimizers, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)
    trainer.test(model)

    logger.experiment.log_artifact(
        logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
    )

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)

    logger.experiment.log_param(
        logger.run_id,
        "num_train_steps_per_epoch",
        dataset.train__len__() / dataset.batch_size,
    )
    logger.experiment.log_param(
        logger.run_id,
        "num_valid_steps_per_epoch",
        dataset.val__len__() / dataset.batch_size,
    )


if __name__ == "__main__":
    main()
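
The script patches the optimizer setup onto the model at runtime instead of subclassing it. Below is a minimal sketch of that MethodType pattern, using torch.nn.Linear and SGD purely as stand-ins (neither is part of this commit); binding the closure is what keeps config, parameters and direction in scope when Lightning later calls model.configure_optimizers().

from types import MethodType

import torch


def configure_optimizers(self):
    # `self` is whatever object the function is bound to below.
    return torch.optim.SGD(self.parameters(), lr=0.01)


model = torch.nn.Linear(8, 1)  # stand-in for the instantiated mayavoz model
model.configure_optimizers = MethodType(configure_optimizers, model)
print(model.configure_optimizers())  # an SGD optimizer over this model's parameters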

View File

@@ -1,7 +0,0 @@
defaults:
- model : Demucs
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default
- trainer : default
- mlflow : experiment
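
The defaults list picks one YAML file from each config group (model, dataset, optimizer, hyperparameters, trainer, mlflow). A hedged sketch of composing the same tree without launching training, assuming the groups live in a train_config/ directory next to the caller as they do for the script; the lr override is illustrative only, and newer Hydra releases may also want a version_base argument:

from hydra import compose, initialize

with initialize(config_path="train_config"):
    cfg = compose(config_name="config", overrides=["hyperparameters.lr=1e-4"])
    print(cfg.model._target_)      # mayavoz.models.demucs.Demucs
    print(cfg.hyperparameters.lr)  # 0.0001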

View File

@@ -1,12 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD
name : MS-SNSD
duration : 2.0
sampling_rate: 16000
batch_size: 32
valid_size: 0.05
files:
  train_clean : CleanSpeech_training
  test_clean : CleanSpeech_training
  train_noisy : NoisySpeech_training
  test_noisy : NoisySpeech_training
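
At duration 2.0 s and sampling_rate 16000, each training example spans 32 000 samples. A quick check of that arithmetic and of the implied batch shape (the (batch, channel, time) layout for mono waveforms is an assumption, not something this file states):

duration, sampling_rate, batch_size = 2.0, 16000, 32

samples_per_example = int(duration * sampling_rate)
print(samples_per_example)                    # 32000
print((batch_size, 1, samples_per_example))   # assumed (batch, channel, time) shape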

View File

@@ -1,13 +0,0 @@
_target_: mayavoz.data.dataset.MayaDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 4.5
stride : 2
sampling_rate: 16000
batch_size: 32
valid_minutes : 15
files:
  train_clean : clean_trainset_28spk_wav
  test_clean : clean_testset_wav
  train_noisy : noisy_trainset_28spk_wav
  test_noisy : noisy_testset_wav
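
Compared with the MS-SNSD file, this config adds a stride, so consecutive 4.5 s windows start every 2 s and overlap. A small sketch of the resulting window count per file, assuming a plain sliding-window chunker (the exact mayavoz slicing may differ):

def num_windows(file_seconds, duration=4.5, stride=2.0):
    # Full windows a sliding chunker can cut from one recording.
    if file_seconds < duration:
        return 0
    return int((file_seconds - duration) // stride) + 1


print(num_windows(10.0))  # 3 windows, starting at 0 s, 2 s and 4 s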

View File

@@ -1,7 +0,0 @@
loss : mae
metric : [stoi,pesq,si-sdr]
lr : 0.0003
ReduceLr_patience : 5
ReduceLr_factor : 0.2
min_lr : 0.000001
EarlyStopping_factor : 10
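
These keys are read by the training script rather than passed to a constructor directly: lr seeds the optimizer, while ReduceLr_patience, ReduceLr_factor and min_lr feed the ReduceLROnPlateau scheduler built inside configure_optimizers. A minimal sketch of that mapping with a stand-in model, assuming mode="min" for a loss that should decrease:

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

params = {"lr": 0.0003, "ReduceLr_patience": 5, "ReduceLr_factor": 0.2, "min_lr": 1e-6}

optimizer = torch.optim.Adam(torch.nn.Linear(4, 1).parameters(), lr=params["lr"])
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=params["ReduceLr_factor"],
    patience=params["ReduceLr_patience"],
    min_lr=params["min_lr"],
)
scheduler.step(0.123)  # called once per epoch with the monitored validation value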

View File

@@ -1,2 +0,0 @@
experiment_name : shahules/mayavoz
run_name : Demucs + Vctk with stride + augmentations
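
Both values go straight into the MLFlowLogger constructed at the top of the training script. A minimal sketch (the local ./mlruns tracking location is MLflow's default and is assumed here, since no tracking URI appears in this config):

from pytorch_lightning.loggers import MLFlowLogger

logger = MLFlowLogger(
    experiment_name="shahules/mayavoz",
    run_name="Demucs + Vctk with stride + augmentations",
    tags={"JOB_ID": "0"},
)
print(logger.run_id)  # the MLflow run is created lazily on first access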

View File

@@ -1,25 +0,0 @@
_target_: mayavoz.models.dccrn.DCCRN
num_channels: 1
sampling_rate : 16000
complex_lstm : True
complex_norm : True
complex_relu : True
masking_mode : True
encoder_decoder:
  initial_output_channels : 32
  depth : 6
  kernel_size : 5
  growth_factor : 2
  stride : 2
  padding : 2
  output_padding : 1
lstm:
  num_layers : 2
  hidden_size : 256
stft:
  window_len : 400
  hop_size : 100
  nfft : 512
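
At 16 kHz these STFT settings mean a 25 ms window, a 6.25 ms hop and nfft/2 + 1 = 257 frequency bins. A quick sketch of the resulting frame count on one second of audio, assuming a centered torch.stft with a Hann window (window type and centering are not specified in this file):

import torch

signal = torch.randn(16000)  # one second at 16 kHz
spec = torch.stft(
    signal,
    n_fft=512,
    hop_length=100,
    win_length=400,
    window=torch.hann_window(400),
    return_complex=True,
)
print(spec.shape)  # torch.Size([257, 161]): nfft//2 + 1 bins, 1 + 16000//100 frames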

View File

@@ -1,16 +0,0 @@
_target_: mayavoz.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate : 16000
encoder_decoder:
  depth: 4
  initial_output_channels: 64
  kernel_size: 8
  stride: 4
  growth_factor: 2
  glu: True
lstm:
  bidirectional: False
  num_layers: 2
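
With initial_output_channels 64, growth_factor 2 and depth 4, these keys imply the usual Demucs-style progression: the encoder widens to 64, 128, 256 and 512 channels while stride 4 shortens the (4x resampled) signal at every layer. A one-line check of that channel progression:

depth, initial_output_channels, growth_factor = 4, 64, 2
print([initial_output_channels * growth_factor**i for i in range(depth)])  # [64, 128, 256, 512]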

View File

@@ -1,5 +0,0 @@
_target_: mayavoz.models.waveunet.WaveUnet
num_channels : 1
depth : 9
initial_output_channels: 24
sampling_rate : 16000

View File

@@ -1,6 +0,0 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
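
This file maps one-to-one onto torch.optim.Adam's constructor arguments; the training script instantiates it with params from the model plus an lr override taken from the hyperparameters group. A minimal sketch of that call with a stand-in module:

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"_target_": "torch.optim.Adam", "lr": 1e-3, "betas": [0.9, 0.999],
     "eps": 1e-08, "weight_decay": 0, "amsgrad": False}
)
model = torch.nn.Linear(4, 1)  # stand-in for the real model
optimizer = instantiate(cfg, params=model.parameters(), lr=0.0003)  # keyword override wins
print(type(optimizer).__name__, optimizer.defaults["lr"])  # Adam 0.0003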

View File

@@ -1,46 +0,0 @@
_target_: pytorch_lightning.Trainer
accelerator: gpu
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: 2
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 200
max_steps: -1
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: ddp
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null

View File

@@ -1,2 +0,0 @@
_target_: pytorch_lightning.Trainer
fast_dev_run: True
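
This second trainer profile is a smoke test: fast_dev_run makes Lightning push a single batch through the train, validation and test loops and then stop, skipping checkpointing and logging. A minimal sketch of instantiating it directly (how this file is selected from the defaults list is not shown in this diff):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "pytorch_lightning.Trainer", "fast_dev_run": True})
trainer = instantiate(cfg)
print(trainer.fast_dev_run)  # True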