configure cli/default arguments

shahules786 2022-09-27 15:52:28 +05:30
parent b742756311
commit e4db841ebb
11 changed files with 155 additions and 0 deletions

cli/train.py

@@ -0,0 +1,35 @@
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.loggers import MLFlowLogger


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
    # Log runs to MLflow under the experiment/run names from the mlflow group.
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
    )
    parameters = config.hyperparameters
    dataset = instantiate(config.dataset)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )
    trainer = instantiate(config.trainer, logger=logger)
    trainer.fit(model)


if __name__ == "__main__":
    main()
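
A note on the instantiate calls above: Hydra merges call-time keyword arguments on top of the fields stored in the YAML node before invoking the _target_, which is how dataset, lr, loss, and metric reach the model constructor without appearing in the model config. A minimal sketch (the Demucs keyword names are assumptions for illustration, not taken from this commit):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# YAML fields and call-time kwargs merge into a single constructor call,
# roughly Demucs(num_channels=1, lr=0.001) here.
model_cfg = OmegaConf.create(
    {"_target_": "enhancer.models.demucs.Demucs", "num_channels": 1}
)
model = instantiate(model_cfg, lr=0.001)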


@@ -0,0 +1,7 @@
defaults:
  - model: Demucs
  - dataset: Vctk
  - optimizer: Adam
  - hyperparameters: default
  - trainer: fastrun_dev
  - mlflow: experiment
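
Each entry in this defaults list names a config group (a subdirectory of train_config/) and the YAML file to load from it, and any group can be swapped at the command line (e.g. model=WaveUnet). A minimal sketch of the same resolution through Hydra's compose API, assuming it is run from the cli/ directory since initialize resolves config_path relative to the caller:

from hydra import compose, initialize

with initialize(config_path="train_config"):
    # Resolves the defaults list above into one merged DictConfig.
    cfg = compose(config_name="config")
    # Swap a whole group programmatically, exactly like a CLI override.
    cfg = compose(config_name="config", overrides=["model=WaveUnet"])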


@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
root_dir: /Users/shahules/Myprojects/enhancer/datasets/vctk_test
name: dns-2020
duration: 1.0
sampling_rate: 16000
batch_size: 32
files:
  root_dir: /Users/shahules/Myprojects/enhancer/datasets/vctk_test
  train_clean: clean_test_wav
  test_clean: clean_test_wav
  train_noisy: clean_test_wav
  test_noisy: clean_test_wav


@@ -0,0 +1,15 @@
_target_: enhancer.data.dataset.EnhancerDataset
root_dir: /Users/shahules/Myprojects/enhancer/datasets/vctk_test
files:
  train_clean: clean_test_wav
  test_clean: clean_test_wav
  train_noisy: clean_test_wav
  test_noisy: clean_test_wav
name: vctk
duration: 1.0
sampling_rate: 48000
batch_size: 32
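
train.py builds this dataset with instantiate(config.dataset); a standalone sketch of the same call follows, with the file path below being an assumption since the commit view does not show where this config lives:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Assumed location of the Vctk dataset config shown above.
ds_cfg = OmegaConf.load("cli/train_config/dataset/Vctk.yaml")
dataset = instantiate(ds_cfg)  # -> enhancer.data.dataset.EnhancerDataset(...)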


@@ -0,0 +1,4 @@
loss: mse
metric: mae
lr: 0.001
num_epochs: 10
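
train.py reads these values with DictConfig.get, which mirrors dict.get and returns None for a missing key instead of raising, so an absent hyperparameter simply falls through to the model's default. A quick sketch:

from omegaconf import OmegaConf

params = OmegaConf.create({"loss": "mse", "metric": "mae", "lr": 0.001})
assert params.get("lr") == 0.001
assert params.get("scheduler") is None  # absent key -> None, no KeyError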


@@ -0,0 +1,2 @@
experiment_name: "myexp"
run_name: "myrun"


@@ -0,0 +1,18 @@
_target_: enhancer.models.demucs.Demucs
num_channels: 1
resample: 4
sampling_rate: 16000
encoder_decoder:
  depth: 5
  initial_output_channels: 48
  kernel_size: 8
  stride: 1
  growth_factor: 2
  glu: True
lstm:
  bidirectional: False
  num_layers: 2


@@ -0,0 +1,5 @@
_target_: enhancer.models.waveunet.WaveUnet
num_channels: 1
depth: 12
initial_output_channels: 24
sampling_rate: 16000


@@ -0,0 +1,6 @@
_target_: torch.optim.Adam
lr: 1e-3
betas: [0.9, 0.999]
eps: 1e-08
weight_decay: 0
amsgrad: False
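
cli/train.py does not consume this optimizer group yet; a config like this is typically instantiated with the model parameters supplied at call time, since they cannot be expressed in YAML. A hedged sketch (the config path is an assumption):

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

opt_cfg = OmegaConf.load("cli/train_config/optimizer/Adam.yaml")  # assumed path
model = torch.nn.Linear(4, 2)  # stand-in module for illustration
# Roughly torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), ...)
optimizer = instantiate(opt_cfg, params=model.parameters())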


@@ -0,0 +1,47 @@
# @package _group_
_target_: pytorch_lightning.Trainer
accelerator: auto
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: False
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
devices: auto
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
fast_dev_run: False
gpus: null
gradient_clip_val: 0
gradient_clip_algorithm: norm
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
max_epochs: 1000
max_steps: null
max_time: null
min_epochs: 1
min_steps: null
move_metrics_to_cpu: False
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: 1
num_sanity_val_steps: 2
overfit_batches: 0.0
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: True
strategy: null
sync_batchnorm: False
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
weights_save_path: null


@@ -0,0 +1,3 @@
# @package _group_
_target_: pytorch_lightning.Trainer
fast_dev_run: True
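
With fast_dev_run: True, Lightning runs a single batch each of train, validation, and test as a smoke test and skips checkpointing and logging; every argument not listed here falls back to the Trainer default. The equivalent direct construction, for reference:

from pytorch_lightning import Trainer

trainer = Trainer(fast_dev_run=True)  # one batch each of train/val/test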