fix logging
commit 09ba645315
@@ -4,7 +4,7 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
@@ -39,21 +39,21 @@ def main(config: DictConfig):
     checkpoint = ModelCheckpoint(
         dirpath="./model",
         filename=f"model_{JOB_ID}",
-        monitor="val_loss",
+        monitor="valid_loss",
         verbose=False,
         mode=direction,
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-        monitor="val_loss",
-        mode=direction,
-        min_delta=1e-7,
-        patience=parameters.get("EarlyStopping_patience", 10),
-        strict=True,
-        verbose=True,
-    )
-    callbacks.append(early_stopping)
+    # early_stopping = EarlyStopping(
+    #     monitor="val_loss",
+    #     mode=direction,
+    #     min_delta=0.0,
+    #     patience=parameters.get("EarlyStopping_patience", 10),
+    #     strict=True,
+    #     verbose=False,
+    # )
+    # callbacks.append(early_stopping)
 
     def configure_optimizer(self):
         optimizer = instantiate(
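Note on the monitor rename above: ModelCheckpoint can only track a metric key that the LightningModule actually logs, so "valid_loss" must match the key logged in validation_step (see the Model hunks further down). A minimal sketch of that coupling, with placeholder values:

    # Sketch only: the monitored key has to match a logged key exactly.
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        monitor="valid_loss",  # same string as the key logged below
        mode="min",            # "min" for a loss; the real code derives this from `direction`
    )

    # inside the LightningModule
    def validation_step(self, batch, batch_idx):
        ...
        self.log("valid_loss", loss)  # key that ModelCheckpoint watches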
@@ -151,9 +151,11 @@ class LossWrapper(nn.Module):
         )
 
         self.higher_better = direction[0]
+        self.name = ""
         for loss in losses:
             loss = self.validate_loss(loss)
             self.valid_losses.append(loss())
+            self.name += f"{loss().name}_"
 
     def validate_loss(self, loss: str):
         if loss not in LOSS_MAP.keys():
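One small consequence of the new name bookkeeping in LossWrapper: each loss appends "{name}_", so the combined name ends with a trailing underscore. A rough illustration, assuming two losses whose .name values are "l1" and "stft" (hypothetical names):

    name = ""
    for loss_name in ["l1", "stft"]:  # stand-ins for loss().name
        name += f"{loss_name}_"
    print(name)  # -> "l1_stft_"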
@@ -132,45 +132,40 @@ class Model(pl.LightningModule):
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
         loss = self.loss(prediction, target)
 
-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss",
-                value=loss.item(),
-                step=self.global_step,
-            )
-        self.log("train_loss", loss.item())
+        self.log(
+            "train_loss",
+            loss.item(),
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
         return {"loss": loss}
 
     def validation_step(self, batch, batch_idx: int):
 
+        metric_dict = {}
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
-        loss_val = self.loss(prediction, target)
-        self.log("val_loss", loss_val.item())
+        metric_dict["valid_loss"] = self.loss(target, prediction).item()
+        for metric in self.metric:
+            value = metric(target, prediction)
+            metric_dict[f"valid_{metric.name}"] = value.item()
 
-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_loss",
-                value=loss_val.item(),
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
 
-        return {"loss": loss_val}
+        return metric_dict
 
     def test_step(self, batch, batch_idx):
 
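Both the training_step hunk above and the validation/test hunks replace direct calls to self.logger.experiment.log_metric (the MLflow client API) with Lightning's own self.log / self.log_dict, which forward values to whatever logger is attached (MLFlowLogger here) and handle the step bookkeeping. A minimal, stand-alone sketch of the pattern, not tied to this model:

    import pytorch_lightning as pl
    import torch

    class LitSketch(pl.LightningModule):
        # Sketch only: shows the self.log / self.log_dict pattern adopted in this commit.
        def training_step(self, batch, batch_idx):
            loss = torch.tensor(0.0, requires_grad=True)  # placeholder loss
            # logged every step, averaged per epoch, shown in the progress bar
            self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            return loss

        def validation_step(self, batch, batch_idx):
            metrics = {"valid_loss": 0.0, "valid_metric": 0.0}  # placeholder values
            # one call logs every key; ModelCheckpoint can then monitor "valid_loss"
            self.log_dict(metrics, on_step=True, on_epoch=True, prog_bar=True)
            return metrics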
@@ -181,46 +176,18 @@ class Model(pl.LightningModule):
 
         for metric in self.metric:
             value = metric(target, prediction)
-            metric_dict[metric.name] = value
+            metric_dict[f"test_{metric.name}"] = value
 
-        for k, v in metric_dict.items():
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key=k,
-                value=v,
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
 
         return metric_dict
 
-    def training_epoch_end(self, outputs):
-        train_mean_loss = 0.0
-        for output in outputs:
-            train_mean_loss += output["loss"]
-        train_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss_epoch",
-                value=train_mean_loss,
-                step=self.current_epoch,
-            )
-
-    def validation_epoch_end(self, outputs):
-        valid_mean_loss = 0.0
-        for output in outputs:
-            valid_mean_loss += output["loss"]
-        valid_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="valid_loss_epoch",
-                value=valid_mean_loss,
-                step=self.current_epoch,
-            )
-
     def test_epoch_end(self, outputs):
 
         test_mean_metrics = defaultdict(int)
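Dropping training_epoch_end and validation_epoch_end leans on the fact that self.log(..., on_epoch=True) already accumulates a per-epoch mean (exposed with an _epoch suffix by default), so the manual averaging loops the old code carried become redundant. Roughly what that reduction computes, as a plain illustration with made-up numbers:

    step_losses = [0.42, 0.40, 0.38]                # step values collected over one epoch
    epoch_mean = sum(step_losses) / len(step_losses)
    print(round(epoch_mean, 2))                     # 0.4, what Lightning logs as "train_loss_epoch"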