Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in:
commit
c9b78b0e73
|
|
@ -3,7 +3,7 @@ from types import MethodType
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
from hydra.utils import instantiate
|
from hydra.utils import instantiate
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
|
@ -15,12 +15,16 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||||
@hydra.main(config_path="train_config", config_name="config")
|
@hydra.main(config_path="train_config", config_name="config")
|
||||||
def main(config: DictConfig):
|
def main(config: DictConfig):
|
||||||
|
|
||||||
|
yaml_conf = OmegaConf.to_yaml(config)
|
||||||
|
OmegaConf.save(yaml_conf, "config_log.yaml")
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
logger = MLFlowLogger(
|
logger = MLFlowLogger(
|
||||||
experiment_name=config.mlflow.experiment_name,
|
experiment_name=config.mlflow.experiment_name,
|
||||||
run_name=config.mlflow.run_name,
|
run_name=config.mlflow.run_name,
|
||||||
tags={"JOB_ID": JOB_ID},
|
tags={"JOB_ID": JOB_ID},
|
||||||
)
|
)
|
||||||
|
logger.experiment.log_artifact(logger.run_id, "config_log.yaml")
|
||||||
|
|
||||||
parameters = config.hyperparameters
|
parameters = config.hyperparameters
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue