diff --git a/cli/train.py b/cli/train.py
new file mode 100644
index 0000000..e6644e8
--- /dev/null
+++ b/cli/train.py
@@ -0,0 +1,35 @@
+import hydra
+from hydra.core.config_store import ConfigStore
+from hydra.utils import instantiate
+
+from omegaconf import DictConfig,OmegaConf
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.loggers import MLFlowLogger
+import torch
+
+
+from enhancer.models.demucs import Demucs
+from enhancer.data.dataset import EnhancerDataset
+
+
+@hydra.main(config_path="train_config",config_name="config")
+def main(config: DictConfig):
+
+    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
+                          run_name=config.mlflow.run_name)
+
+
+    parameters = config.hyperparameters
+
+    dataset = instantiate(config.dataset)
+    model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
+                        loss=parameters.get("loss"), metric = parameters.get("metric"))
+
+    trainer = instantiate(config.trainer,logger=logger)
+    trainer.fit(model)
+
+
+
+if __name__=="__main__":
+    main()
\ No newline at end of file
diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
new file mode 100644
index 0000000..7845b01
--- /dev/null
+++ b/cli/train_config/config.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - model : Demucs
+  - dataset : Vctk
+  - optimizer : Adam
+  - hyperparameters : default
+  - trainer : fastrun_dev
+  - mlflow : experiment
\ No newline at end of file
diff --git a/cli/train_config/dataset/DNS-2020.yaml b/cli/train_config/dataset/DNS-2020.yaml
new file mode 100644
index 0000000..f59cb2b
--- /dev/null
+++ b/cli/train_config/dataset/DNS-2020.yaml
@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
+name : dns-2020
+duration : 1.0
+sampling_rate: 16000
+batch_size: 32
+files:
+  root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
+  train_clean : clean_test_wav
+  test_clean : clean_test_wav
+  train_noisy : clean_test_wav
+  test_noisy : clean_test_wav
+
diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml
new file mode 100644
index 0000000..1788177
--- /dev/null
+++ b/cli/train_config/dataset/Vctk.yaml
@@ -0,0 +1,15 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk_test
+
+files:
+  train_clean : clean_test_wav
+  test_clean : clean_test_wav
+  train_noisy : clean_test_wav
+  test_noisy : clean_test_wav
+
+name : vctk
+duration : 1.0
+sampling_rate: 48000
+batch_size: 32
+
+
diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml
new file mode 100644
index 0000000..5cbdcb0
--- /dev/null
+++ b/cli/train_config/hyperparameters/default.yaml
@@ -0,0 +1,4 @@
+loss : mse
+metric : mae
+lr : 0.001
+num_epochs : 10
diff --git a/cli/train_config/mlflow/experiment.yaml b/cli/train_config/mlflow/experiment.yaml
new file mode 100644
index 0000000..b64b125
--- /dev/null
+++ b/cli/train_config/mlflow/experiment.yaml
@@ -0,0 +1,2 @@
+experiment_name : "myexp"
+run_name : "myrun"
\ No newline at end of file
diff --git a/cli/train_config/model/Demucs.yaml b/cli/train_config/model/Demucs.yaml
new file mode 100644
index 0000000..27603dc
--- /dev/null
+++ b/cli/train_config/model/Demucs.yaml
@@ -0,0 +1,18 @@
+_target_: enhancer.models.demucs.Demucs
+num_channels: 1
+resample: 4
+sampling_rate : 16000
+
+encoder_decoder:
+  depth: 5
+  initial_output_channels: 48
+  kernel_size: 8
+  stride: 1
+  growth_factor: 2
+  glu: True
+
+lstm:
+  bidirectional: False
+  num_layers: 2
+
+
diff --git a/cli/train_config/model/WaveUnet.yaml b/cli/train_config/model/WaveUnet.yaml
new file mode 100644
index 0000000..d641bcd
--- /dev/null
+++ b/cli/train_config/model/WaveUnet.yaml
@@ -0,0 +1,5 @@
+_target_: enhancer.models.waveunet.WaveUnet
+num_channels : 1
+depth : 12
+initial_output_channels: 24
+sampling_rate : 16000
diff --git a/cli/train_config/optimizer/Adam.yaml b/cli/train_config/optimizer/Adam.yaml
new file mode 100644
index 0000000..7952b81
--- /dev/null
+++ b/cli/train_config/optimizer/Adam.yaml
@@ -0,0 +1,6 @@
+_target_: torch.optim.Adam
+lr: 1e-3
+betas: [0.9, 0.999]
+eps: 1e-08
+weight_decay: 0
+amsgrad: False
diff --git a/cli/train_config/trainer/default.yml b/cli/train_config/trainer/default.yml
new file mode 100644
index 0000000..eeb5b85
--- /dev/null
+++ b/cli/train_config/trainer/default.yml
@@ -0,0 +1,47 @@
+# @package _group_
+_target_: pytorch_lightning.Trainer
+accelerator: auto
+accumulate_grad_batches: 1
+amp_backend: native
+auto_lr_find: False
+auto_scale_batch_size: False
+auto_select_gpus: True
+benchmark: False
+check_val_every_n_epoch: 1
+detect_anomaly: False
+deterministic: False
+devices: auto
+enable_checkpointing: True
+enable_model_summary: True
+enable_progress_bar: True
+fast_dev_run: False
+gpus: null
+gradient_clip_val: 0
+gradient_clip_algorithm: norm
+ipus: null
+limit_predict_batches: 1.0
+limit_test_batches: 1.0
+limit_train_batches: 1.0
+limit_val_batches: 1.0
+log_every_n_steps: 50
+max_epochs: 1000
+max_steps: null
+max_time: null
+min_epochs: 1
+min_steps: null
+move_metrics_to_cpu: False
+multiple_trainloader_mode: max_size_cycle
+num_nodes: 1
+num_processes: 1
+num_sanity_val_steps: 2
+overfit_batches: 0.0
+precision: 32
+profiler: null
+reload_dataloaders_every_n_epochs: 0
+replace_sampler_ddp: True
+strategy: null
+sync_batchnorm: False
+tpu_cores: null
+track_grad_norm: -1
+val_check_interval: 1.0
+weights_save_path: null
diff --git a/cli/train_config/trainer/fastrun_dev.yaml b/cli/train_config/trainer/fastrun_dev.yaml
new file mode 100644
index 0000000..5d0895f
--- /dev/null
+++ b/cli/train_config/trainer/fastrun_dev.yaml
@@ -0,0 +1,3 @@
+# @package _group_
+_target_: pytorch_lightning.Trainer
+fast_dev_run: True