diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 5e12142..49f7b3b 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -4,7 +4,7 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
@@ -39,21 +39,21 @@ def main(config: DictConfig):
     checkpoint = ModelCheckpoint(
         dirpath="./model",
         filename=f"model_{JOB_ID}",
-        monitor="val_loss",
+        monitor="valid_loss",
         verbose=False,
         mode=direction,
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-        monitor="val_loss",
-        mode=direction,
-        min_delta=1e-7,
-        patience=parameters.get("EarlyStopping_patience", 10),
-        strict=True,
-        verbose=True,
-    )
-    callbacks.append(early_stopping)
+    # early_stopping = EarlyStopping(
+    #     monitor="val_loss",
+    #     mode=direction,
+    #     min_delta=0.0,
+    #     patience=parameters.get("EarlyStopping_patience", 10),
+    #     strict=True,
+    #     verbose=False,
+    # )
+    # callbacks.append(early_stopping)
 
     def configure_optimizer(self):
         optimizer = instantiate(
diff --git a/enhancer/loss.py b/enhancer/loss.py
index fd1a609..cdd15a5 100644
--- a/enhancer/loss.py
+++ b/enhancer/loss.py
@@ -151,9 +151,11 @@ class LossWrapper(nn.Module):
         )
         self.higher_better = direction[0]
+        self.name = ""
 
         for loss in losses:
             loss = self.validate_loss(loss)
             self.valid_losses.append(loss())
+            self.name += f"{loss().name}_"
 
     def validate_loss(self, loss: str):
         if loss not in LOSS_MAP.keys():
diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 3f74a74..4f055b4 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -132,45 +132,40 @@ class Model(pl.LightningModule):
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
-        loss = self.loss(prediction, target)
-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss",
-                value=loss.item(),
-                step=self.global_step,
-            )
-        self.log("train_loss", loss.item())
+        loss = self.loss(target, prediction)
+        self.log(
+            "train_loss",
+            loss.item(),
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return {"loss": loss}
 
     def validation_step(self, batch, batch_idx: int):
+        metric_dict = {}
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
-        loss_val = self.loss(prediction, target)
-        self.log("val_loss", loss_val.item())
+        metric_dict["valid_loss"] = self.loss(target, prediction).item()
+        for metric in self.metric:
+            value = metric(target, prediction)
+            metric_dict[f"valid_{metric.name}"] = value.item()
 
-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_loss",
-                value=loss_val.item(),
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
 
-        return {"loss": loss_val}
+        return metric_dict
 
     def test_step(self, batch, batch_idx):
@@ -181,46 +176,18 @@ class Model(pl.LightningModule):
         for metric in self.metric:
             value = metric(target, prediction)
-            metric_dict[metric.name] = value
+            metric_dict[f"test_{metric.name}"] = value
 
-        for k, v in metric_dict.items():
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key=k,
-                value=v,
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
 
         return metric_dict
 
-    def training_epoch_end(self, outputs):
-        train_mean_loss = 0.0
-        for output in outputs:
-            train_mean_loss += output["loss"]
-        train_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss_epoch",
-                value=train_mean_loss,
-                step=self.current_epoch,
-            )
-
-    def validation_epoch_end(self, outputs):
-        valid_mean_loss = 0.0
-        for output in outputs:
-            valid_mean_loss += output["loss"]
-        valid_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="valid_loss_epoch",
-                value=valid_mean_loss,
-                step=self.current_epoch,
-            )
-
     def test_epoch_end(self, outputs):
         test_mean_metrics = defaultdict(int)
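
Note on the logging change (not part of the patch): the diff drops the hand-rolled MLflow `log_metric` calls and the `training_epoch_end`/`validation_epoch_end` hooks in favour of Lightning's `self.log`/`self.log_dict`, and points `ModelCheckpoint` at the new `valid_loss` key. The sketch below is a minimal, self-contained illustration of that pattern; the toy model, MSE loss, and random data are placeholders, not code from this repository.

```python
# Illustrative only: ToyDenoiser, the Linear "network", the MSE loss and the
# random tensors are stand-ins; the self.log / self.log_dict calls and the
# ModelCheckpoint(monitor="valid_loss") wiring mirror what the diff switches to.
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset


class ToyDenoiser(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        noisy, clean = batch
        loss = torch.nn.functional.mse_loss(self(noisy), clean)
        # on_step=True logs the per-step value; on_epoch=True makes Lightning
        # accumulate the epoch mean itself, so no training_epoch_end hook is needed.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        noisy, clean = batch
        metric_dict = {"valid_loss": torch.nn.functional.mse_loss(self(noisy), clean)}
        # log_dict aggregates every key across the epoch; the checkpoint callback
        # below reads the epoch-level "valid_loss" from the same mechanism.
        self.log_dict(metric_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return metric_dict

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    data = TensorDataset(torch.randn(64, 16), torch.randn(64, 16))
    loader = DataLoader(data, batch_size=8)
    checkpoint = ModelCheckpoint(monitor="valid_loss", mode="min", every_n_epochs=1)
    trainer = pl.Trainer(max_epochs=1, callbacks=[checkpoint], logger=False)
    trainer.fit(ToyDenoiser(), loader, loader)
```

With both `on_step` and `on_epoch` enabled, the epoch-aggregated value is what the checkpoint callback monitors, which is why the `ModelCheckpoint` in `train.py` now watches `valid_loss` instead of the removed `val_loss` key.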