diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml index 602cbc2..91caa91 100644 --- a/enhancer/cli/train_config/trainer/default.yaml +++ b/enhancer/cli/train_config/trainer/default.yaml @@ -1,5 +1,5 @@ _target_: pytorch_lightning.Trainer -accelerator: auto +accelerator: gpu accumulate_grad_batches: 1 amp_backend: native auto_lr_find: False @@ -9,7 +9,7 @@ benchmark: False check_val_every_n_epoch: 1 detect_anomaly: False deterministic: False -devices: 2 +devices: -1 enable_checkpointing: True enable_model_summary: True enable_progress_bar: True