fix logging

This commit is contained in:
shahules786 2022-10-12 20:23:55 +05:30
commit 09ba645315
3 changed files with 43 additions and 74 deletions

View File

@ -4,7 +4,7 @@ from types import MethodType
import hydra import hydra
from hydra.utils import instantiate from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
@ -39,21 +39,21 @@ def main(config: DictConfig):
checkpoint = ModelCheckpoint( checkpoint = ModelCheckpoint(
dirpath="./model", dirpath="./model",
filename=f"model_{JOB_ID}", filename=f"model_{JOB_ID}",
monitor="val_loss", monitor="valid_loss",
verbose=False, verbose=False,
mode=direction, mode=direction,
every_n_epochs=1, every_n_epochs=1,
) )
callbacks.append(checkpoint) callbacks.append(checkpoint)
early_stopping = EarlyStopping( # early_stopping = EarlyStopping(
monitor="val_loss", # monitor="val_loss",
mode=direction, # mode=direction,
min_delta=1e-7, # min_delta=0.0,
patience=parameters.get("EarlyStopping_patience", 10), # patience=parameters.get("EarlyStopping_patience", 10),
strict=True, # strict=True,
verbose=True, # verbose=False,
) # )
callbacks.append(early_stopping) # callbacks.append(early_stopping)
def configure_optimizer(self): def configure_optimizer(self):
optimizer = instantiate( optimizer = instantiate(

View File

@ -151,9 +151,11 @@ class LossWrapper(nn.Module):
) )
self.higher_better = direction[0] self.higher_better = direction[0]
self.name = ""
for loss in losses: for loss in losses:
loss = self.validate_loss(loss) loss = self.validate_loss(loss)
self.valid_losses.append(loss()) self.valid_losses.append(loss())
self.name += f"{loss().name}_"
def validate_loss(self, loss: str): def validate_loss(self, loss: str):
if loss not in LOSS_MAP.keys(): if loss not in LOSS_MAP.keys():

View File

@ -132,45 +132,40 @@ class Model(pl.LightningModule):
mixed_waveform = batch["noisy"] mixed_waveform = batch["noisy"]
target = batch["clean"] target = batch["clean"]
prediction = self(mixed_waveform) prediction = self(mixed_waveform)
loss = self.loss(prediction, target) loss = self.loss(prediction, target)
if ( self.log(
(self.logger) "train_loss",
and (self.global_step > 50) loss.item(),
and (self.global_step % 50 == 0) on_epoch=True,
): on_step=True,
self.logger.experiment.log_metric( logger=True,
run_id=self.logger.run_id, prog_bar=True,
key="train_loss",
value=loss.item(),
step=self.global_step,
) )
self.log("train_loss", loss.item())
return {"loss": loss} return {"loss": loss}
def validation_step(self, batch, batch_idx: int): def validation_step(self, batch, batch_idx: int):
metric_dict = {}
mixed_waveform = batch["noisy"] mixed_waveform = batch["noisy"]
target = batch["clean"] target = batch["clean"]
prediction = self(mixed_waveform) prediction = self(mixed_waveform)
loss_val = self.loss(prediction, target) metric_dict["valid_loss"] = self.loss(target, prediction).item()
self.log("val_loss", loss_val.item()) for metric in self.metric:
value = metric(target, prediction)
metric_dict[f"valid_{metric.name}"] = value.item()
if ( self.log_dict(
(self.logger) metric_dict,
and (self.global_step > 50) on_step=True,
and (self.global_step % 50 == 0) on_epoch=True,
): prog_bar=True,
self.logger.experiment.log_metric( logger=True,
run_id=self.logger.run_id,
key="val_loss",
value=loss_val.item(),
step=self.global_step,
) )
return {"loss": loss_val} return metric_dict
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
@ -181,46 +176,18 @@ class Model(pl.LightningModule):
for metric in self.metric: for metric in self.metric:
value = metric(target, prediction) value = metric(target, prediction)
metric_dict[metric.name] = value metric_dict[f"test_{metric.name}"] = value
for k, v in metric_dict.items(): self.log_dict(
self.logger.experiment.log_metric( metric_dict,
run_id=self.logger.run_id, on_step=True,
key=k, on_epoch=True,
value=v, prog_bar=True,
step=self.global_step, logger=True,
) )
return metric_dict return metric_dict
def training_epoch_end(self, outputs):
train_mean_loss = 0.0
for output in outputs:
train_mean_loss += output["loss"]
train_mean_loss /= len(outputs)
if self.logger:
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="train_loss_epoch",
value=train_mean_loss,
step=self.current_epoch,
)
def validation_epoch_end(self, outputs):
valid_mean_loss = 0.0
for output in outputs:
valid_mean_loss += output["loss"]
valid_mean_loss /= len(outputs)
if self.logger:
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="valid_loss_epoch",
value=valid_mean_loss,
step=self.current_epoch,
)
def test_epoch_end(self, outputs): def test_epoch_end(self, outputs):
test_mean_metrics = defaultdict(int) test_mean_metrics = defaultdict(int)