configure cli/default arguments
This commit is contained in:
parent
b742756311
commit
e4db841ebb
|
|
@ -0,0 +1,35 @@
|
||||||
|
import hydra
|
||||||
|
from hydra.core.config_store import ConfigStore
|
||||||
|
from hydra.utils import instantiate
|
||||||
|
|
||||||
|
from omegaconf import DictConfig,OmegaConf
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||||
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
from enhancer.models.demucs import Demucs
|
||||||
|
from enhancer.data.dataset import EnhancerDataset
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(config_path="train_config",config_name="config")
|
||||||
|
def main(config: DictConfig):
|
||||||
|
|
||||||
|
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
|
||||||
|
run_name=config.mlflow.run_name)
|
||||||
|
|
||||||
|
|
||||||
|
parameters = config.hyperparameters
|
||||||
|
|
||||||
|
dataset = instantiate(config.dataset)
|
||||||
|
model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
|
||||||
|
loss=parameters.get("loss"), metric = parameters.get("metric"))
|
||||||
|
|
||||||
|
trainer = instantiate(config.trainer,logger=logger)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
defaults:
|
||||||
|
- model : Demucs
|
||||||
|
- dataset : Vctk
|
||||||
|
- optimizer : Adam
|
||||||
|
- hyperparameters : default
|
||||||
|
- trainer : fastrun_dev
|
||||||
|
- mlflow : experiment
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
_target_: enhancer.data.dataset.EnhancerDataset
|
||||||
|
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
|
||||||
|
name : dns-2020
|
||||||
|
duration : 1.0
|
||||||
|
sampling_rate: 16000
|
||||||
|
batch_size: 32
|
||||||
|
files:
|
||||||
|
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
|
||||||
|
train_clean : clean_test_wav
|
||||||
|
test_clean : clean_test_wav
|
||||||
|
train_noisy : clean_test_wav
|
||||||
|
test_noisy : clean_test_wav
|
||||||
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
_target_: enhancer.data.dataset.EnhancerDataset
|
||||||
|
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
|
||||||
|
|
||||||
|
files:
|
||||||
|
train_clean : clean_test_wav
|
||||||
|
test_clean : clean_test_wav
|
||||||
|
train_noisy : clean_test_wav
|
||||||
|
test_noisy : clean_test_wav
|
||||||
|
|
||||||
|
name : vctk
|
||||||
|
duration : 1.0
|
||||||
|
sampling_rate: 48000
|
||||||
|
batch_size: 32
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
loss : mse
|
||||||
|
metric : mae
|
||||||
|
lr : 0.001
|
||||||
|
num_epochs : 10
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
experiment_name : "myexp"
|
||||||
|
run_name : "myrun"
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
_target_: enhancer.models.demucs.Demucs
|
||||||
|
num_channels: 1
|
||||||
|
resample: 4
|
||||||
|
sampling_rate : 16000
|
||||||
|
|
||||||
|
encoder_decoder:
|
||||||
|
depth: 5
|
||||||
|
initial_output_channels: 48
|
||||||
|
kernel_size: 8
|
||||||
|
stride: 1
|
||||||
|
growth_factor: 2
|
||||||
|
glu: True
|
||||||
|
|
||||||
|
lstm:
|
||||||
|
bidirectional: False
|
||||||
|
num_layers: 2
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
_target_: enhancer.models.waveunet.WaveUnet
|
||||||
|
num_channels : 1
|
||||||
|
depth : 12
|
||||||
|
initial_output_channels: 24
|
||||||
|
sampling_rate : 16000
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
_target_: torch.optim.Adam
|
||||||
|
lr: 1e-3
|
||||||
|
betas: [0.9, 0.999]
|
||||||
|
eps: 1e-08
|
||||||
|
weight_decay: 0
|
||||||
|
amsgrad: False
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
# @package _group_
|
||||||
|
_target_: pytorch_lightning.Trainer
|
||||||
|
accelerator: auto
|
||||||
|
accumulate_grad_batches: 1
|
||||||
|
amp_backend: native
|
||||||
|
auto_lr_find: False
|
||||||
|
auto_scale_batch_size: False
|
||||||
|
auto_select_gpus: True
|
||||||
|
benchmark: False
|
||||||
|
check_val_every_n_epoch: 1
|
||||||
|
detect_anomaly: False
|
||||||
|
deterministic: False
|
||||||
|
devices: auto
|
||||||
|
enable_checkpointing: True
|
||||||
|
enable_model_summary: True
|
||||||
|
enable_progress_bar: True
|
||||||
|
fast_dev_run: False
|
||||||
|
gpus: null
|
||||||
|
gradient_clip_val: 0
|
||||||
|
gradient_clip_algorithm: norm
|
||||||
|
ipus: null
|
||||||
|
limit_predict_batches: 1.0
|
||||||
|
limit_test_batches: 1.0
|
||||||
|
limit_train_batches: 1.0
|
||||||
|
limit_val_batches: 1.0
|
||||||
|
log_every_n_steps: 50
|
||||||
|
max_epochs: 1000
|
||||||
|
max_steps: null
|
||||||
|
max_time: null
|
||||||
|
min_epochs: 1
|
||||||
|
min_steps: null
|
||||||
|
move_metrics_to_cpu: False
|
||||||
|
multiple_trainloader_mode: max_size_cycle
|
||||||
|
num_nodes: 1
|
||||||
|
num_processes: 1
|
||||||
|
num_sanity_val_steps: 2
|
||||||
|
overfit_batches: 0.0
|
||||||
|
precision: 32
|
||||||
|
profiler: null
|
||||||
|
reload_dataloaders_every_n_epochs: 0
|
||||||
|
replace_sampler_ddp: True
|
||||||
|
strategy: null
|
||||||
|
sync_batchnorm: False
|
||||||
|
tpu_cores: null
|
||||||
|
track_grad_norm: -1
|
||||||
|
val_check_interval: 1.0
|
||||||
|
weights_save_path: null
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
# @package _group_
|
||||||
|
_target_: pytorch_lightning.Trainer
|
||||||
|
fast_dev_run: True
|
||||||
Loading…
Reference in New Issue