merge & test

commit b0483d2fa8

cli/train.py (23 changed lines)
@@ -1,7 +1,9 @@
 import os
+from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
@@ -23,7 +25,7 @@ def main(config: DictConfig):

     direction = model.valid_monitor
     checkpoint = ModelCheckpoint(
-        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=False,
+        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
         mode=direction,every_n_epochs=1
     )
     callbacks.append(checkpoint)
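The only functional change in this hunk is verbose=False -> verbose=True, so the callback now logs a line each time it writes a checkpoint. A minimal self-contained sketch of the same callback (JOB_ID and the "min" mode are illustrative stand-ins for values the script derives elsewhere):

from pytorch_lightning.callbacks import ModelCheckpoint

JOB_ID = "demo"  # assumption: defined elsewhere in the real script
checkpoint = ModelCheckpoint(
    dirpath="./model",
    filename=f"model_{JOB_ID}",   # ModelCheckpoint appends ".ckpt"
    monitor="val_loss",           # must match a metric logged via self.log()
    mode="min",                   # model.valid_monitor in the patched script
    every_n_epochs=1,
    verbose=True,                 # the change: announce every save
)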
@@ -31,15 +33,30 @@ def main(config: DictConfig):
         monitor="val_loss",
         mode=direction,
         min_delta=0.0,
-        patience=10,
+        patience=parameters.get("EarlyStopping_patience",10),
         strict=True,
         verbose=False,
     )
     callbacks.append(early_stopping)

+    def configure_optimizer(self):
+        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=parameters.get("ReduceLr_factor",0.1),
+            verbose=True,
+            min_lr=parameters.get("min_lr",1e-6),
+            patience=parameters.get("ReduceLr_patience",3)
+        )
+        return {"optimizer":optimizer, "lr_scheduler":scheduler}
+
+    model.configure_parameters = MethodType(configure_optimizer,model)
+
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
-    logger.experiment.log_artifact(logger.run_id,f"./model/model_{JOB_ID}")
+    if os.path.exists("./model/"):
+        logger.experiment.log_artifact(logger.run_id,f"./model/.*")
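The new block builds an optimizer plus a ReduceLROnPlateau scheduler and attaches the factory to the model at runtime with types.MethodType. A self-contained sketch of that pattern, with the Hydra pieces replaced by plain dict lookups (the class, values, and metric name are illustrative, not the repository's own):

from types import MethodType

import torch
import pytorch_lightning as pl
from torch.optim.lr_scheduler import ReduceLROnPlateau

parameters = {"lr": 1e-4, "ReduceLr_factor": 0.1,
              "ReduceLr_patience": 3, "min_lr": 1e-6}
direction = "min"  # stand-in for model.valid_monitor

class DemoModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=parameters["lr"])
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode=direction,
        factor=parameters.get("ReduceLr_factor", 0.1),
        patience=parameters.get("ReduceLr_patience", 3),
        min_lr=parameters.get("min_lr", 1e-6),
    )
    # A plateau scheduler needs to know which logged metric to watch.
    return {"optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}

model = DemoModel()
# Bind the function as a method on this single instance.
model.configure_optimizers = MethodType(configure_optimizers, model)

Two details in the diff are worth flagging: Lightning invokes a hook named configure_optimizers, while the commit binds the function to model.configure_parameters, a name the framework never calls; and when the scheduler is ReduceLROnPlateau, Lightning expects a "monitor" entry alongside it rather than the flat {"optimizer": ..., "lr_scheduler": scheduler} return.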
@@ -1,5 +1,5 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
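This switches the default model group from Demucs to WaveUnet; the other groups are untouched. Each entry names a Hydra config group that hydra.utils.instantiate() later turns into an object. A minimal sketch of that mechanism for the "- optimizer : Adam" entry (the _target_ and file layout are assumptions, not the repository's own):

from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch

# Equivalent of a hypothetical conf/optimizer/Adam.yaml selected by the default.
cfg = OmegaConf.create({"optimizer": {"_target_": "torch.optim.Adam", "lr": 1e-4}})
net = torch.nn.Linear(8, 1)
optimizer = instantiate(cfg.optimizer, params=net.parameters())
print(type(optimizer).__name__)  # Adam

Note that torch.optim.Adam takes its parameter iterable as params; the call in train.py passes parameters=self.parameters(), which the stock Adam class would reject as an unexpected keyword if it is the _target_.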
@@ -1,3 +1,9 @@
 loss : mse
 metric : mae
 lr : 0.0001
+num_epochs : 100
+ReduceLr_patience : 5
+ReduceLr_factor : 0.1
+min_lr : 0.000001
+EarlyStopping_factor : 10
+
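One mismatch worth noting: train.py reads parameters.get("EarlyStopping_patience", 10), but the key added here is EarlyStopping_factor, so the .get() fallback of 10 is what actually takes effect. A tiny sketch of the lookup:

parameters = {
    "loss": "mse", "metric": "mae", "lr": 0.0001, "num_epochs": 100,
    "ReduceLr_patience": 5, "ReduceLr_factor": 0.1, "min_lr": 0.000001,
    "EarlyStopping_factor": 10,   # name does not match the code's lookup key
}
patience = parameters.get("EarlyStopping_patience", 10)  # -> 10, the fallback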
@@ -22,8 +22,8 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 50
-max_epochs: 500
+log_every_n_steps: 10
+max_epochs: 30
 max_steps: null
 max_time: null
 min_epochs: 1
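These two edits shorten test runs: logged metrics are written every 10 steps instead of 50, and training caps at 30 epochs instead of 500. A sketch of how such a trainer config reaches pytorch_lightning.Trainer through instantiate() in train.py (the _target_ key is an assumption about the omitted top of the file):

from hydra.utils import instantiate
from omegaconf import OmegaConf

trainer_cfg = OmegaConf.create({
    "_target_": "pytorch_lightning.Trainer",
    "log_every_n_steps": 10,  # was 50
    "max_epochs": 30,         # was 500
})
trainer = instantiate(trainer_cfg)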