Merge pull request #2 from shahules786/dev-hawk

include model artifacts logging
This commit is contained in:
Shahul ES 2022-10-02 09:04:24 +05:30 committed by GitHub
commit bf07cae301
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 6 deletions

View File

@ -5,13 +5,14 @@ from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID")
@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
callbacks = []
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
parameters = config.hyperparameters
@ -22,7 +23,7 @@ def main(config: DictConfig):
direction = model.valid_monitor
checkpoint = ModelCheckpoint(
dirpath="",filename="model",monitor="val_loss",verbose=False,
dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
mode=direction,every_n_epochs=1
)
callbacks.append(checkpoint)
@ -30,7 +31,7 @@ def main(config: DictConfig):
monitor="val_loss",
mode=direction,
min_delta=0.0,
patience=100,
patience=10,
strict=True,
verbose=False,
)
@ -38,6 +39,7 @@ def main(config: DictConfig):
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model)
logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")

View File

@ -1,5 +1,5 @@
defaults:
- model : WaveUnet
- model : Demucs
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default

View File

@ -1,11 +1,11 @@
_target_: enhancer.models.demucs.Demucs
num_channels: 1
resample: 4
resample: 2
sampling_rate : 16000
encoder_decoder:
depth: 5
initial_output_channels: 48
initial_output_channels: 32
kernel_size: 8
stride: 1
growth_factor: 2