diff --git a/cli/train.py b/cli/train.py index d951f33..7d49acc 100644 --- a/cli/train.py +++ b/cli/train.py @@ -33,7 +33,7 @@ def main(config: DictConfig): monitor="val_loss", mode=direction, min_delta=0.0, - patience=10, + patience=parameters.get("EarlyStopping_patience",10), strict=True, verbose=False, ) @@ -44,10 +44,10 @@ def main(config: DictConfig): scheduler = ReduceLROnPlateau( optimizer=optimizer, mode=direction, - factor=0.1, + factor=parameters.get("ReduceLr_factor",0.1), verbose=True, min_lr=parameters.get("min_lr",1e-6), - patience=3 + patience=parameters.get("ReduceLr_patience",3) ) return {"optimizer":optimizer, "lr_scheduler":scheduler}