From 8ac01b846de688d4e060a3c52ed5106e8a8e63bd Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:54:10 +0530 Subject: [PATCH] black --- enhancer/cli/train.py | 79 ++++++++++++++++++++++++++----------------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py index dee3d2e..814fa0f 100644 --- a/enhancer/cli/train.py +++ b/enhancer/cli/train.py @@ -1,4 +1,3 @@ -from genericpath import isfile import os from types import MethodType import hydra @@ -7,61 +6,79 @@ from omegaconf import DictConfig from torch.optim.lr_scheduler import ReduceLROnPlateau from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import MLFlowLogger -os.environ["HYDRA_FULL_ERROR"] = "1" -JOB_ID = os.environ.get("SLURM_JOBID","0") -@hydra.main(config_path="train_config",config_name="config") +os.environ["HYDRA_FULL_ERROR"] = "1" +JOB_ID = os.environ.get("SLURM_JOBID", "0") + + +@hydra.main(config_path="train_config", config_name="config") def main(config: DictConfig): callbacks = [] - logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name, - run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID}) - + logger = MLFlowLogger( + experiment_name=config.mlflow.experiment_name, + run_name=config.mlflow.run_name, + tags={"JOB_ID": JOB_ID}, + ) parameters = config.hyperparameters dataset = instantiate(config.dataset) - model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"), - loss=parameters.get("loss"), metric = parameters.get("metric")) + model = instantiate( + config.model, + dataset=dataset, + lr=parameters.get("lr"), + loss=parameters.get("loss"), + metric=parameters.get("metric"), + ) direction = model.valid_monitor checkpoint = ModelCheckpoint( - dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True, - mode=direction,every_n_epochs=1 + dirpath="./model", + filename=f"model_{JOB_ID}", + monitor="val_loss", + verbose=True, + mode=direction, + every_n_epochs=1, ) callbacks.append(checkpoint) early_stopping = EarlyStopping( - monitor="val_loss", - mode=direction, - min_delta=0.0, - patience=parameters.get("EarlyStopping_patience",10), - strict=True, - verbose=False, - ) + monitor="val_loss", + mode=direction, + min_delta=0.0, + patience=parameters.get("EarlyStopping_patience", 10), + strict=True, + verbose=False, + ) callbacks.append(early_stopping) - + def configure_optimizer(self): - optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters()) + optimizer = instantiate( + config.optimizer, + lr=parameters.get("lr"), + parameters=self.parameters(), + ) scheduler = ReduceLROnPlateau( optimizer=optimizer, mode=direction, - factor=parameters.get("ReduceLr_factor",0.1), + factor=parameters.get("ReduceLr_factor", 0.1), verbose=True, - min_lr=parameters.get("min_lr",1e-6), - patience=parameters.get("ReduceLr_patience",3) + min_lr=parameters.get("min_lr", 1e-6), + patience=parameters.get("ReduceLr_patience", 3), ) - return {"optimizer":optimizer, "lr_scheduler":scheduler} + return {"optimizer": optimizer, "lr_scheduler": scheduler} - model.configure_parameters = MethodType(configure_optimizer,model) + model.configure_parameters = MethodType(configure_optimizer, model) - trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) + trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks) trainer.fit(model) - saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt") + saved_location = os.path.join( + trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt" + ) if os.path.isfile(saved_location): - logger.experiment.log_artifact(logger.run_id,saved_location) + logger.experiment.log_artifact(logger.run_id, saved_location) - -if __name__=="__main__": - main() \ No newline at end of file +if __name__ == "__main__": + main()