From fdce8bb601232ca5e485b061074c5333266a246e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 11 Oct 2022 15:11:50 +0530
Subject: [PATCH 1/2] rmv inplace operation

---
 enhancer/models/demucs.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
index 5d7e99f..95d6a6f 100644
--- a/enhancer/models/demucs.py
+++ b/enhancer/models/demucs.py
@@ -226,7 +226,7 @@ class Demucs(Model):
         x = x.permute(0, 2, 1)
         for decoder in self.decoder:
             skip_connection = encoder_outputs.pop(-1)
-            x += skip_connection[..., : x.shape[-1]]
+            x = x + skip_connection[..., : x.shape[-1]]
             x = decoder(x)

         if self.hparams.resample > 1:
@@ -236,7 +236,8 @@ class Demucs(Model):
             self.hparams.sampling_rate,
         )

-        return x[..., :length]
+        out = x[..., :length]
+        return out

     def get_padding_length(self, input_length):

From abcdc29309e83660ecd223e06a71283370087782 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Tue, 11 Oct 2022 16:48:49 +0530
Subject: [PATCH 2/2] log average metrics

---
 enhancer/models/model.py | 51 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 47 insertions(+), 4 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 7dbc065..04f79a8 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -1,7 +1,8 @@
 import os
+from collections import defaultdict
 from importlib import import_module
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Text, Union
+from typing import List, Optional, Text, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -192,6 +193,51 @@ class Model(pl.LightningModule):

         return metric_dict

+    def training_epoch_end(self, outputs):
+        train_mean_loss = 0.0
+        for output in outputs:
+            train_mean_loss += output["loss"]
+        train_mean_loss /= len(outputs)
+
+        if self.logger:
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="train_loss_epoch",
+                value=train_mean_loss,
+                step=self.current_epoch,
+            )
+
+    def validation_epoch_end(self, outputs):
+        valid_mean_loss = 0.0
+        for output in outputs:
+            valid_mean_loss += output["loss"]
+        valid_mean_loss /= len(outputs)
+
+        if self.logger:
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key="valid_loss_epoch",
+                value=valid_mean_loss,
+                step=self.current_epoch,
+            )
+
+    def test_epoch_end(self, outputs):
+
+        test_mean_metrics = defaultdict(int)
+        for output in outputs:
+            for metric, value in output.items():
+                test_mean_metrics[metric] += value.item()
+        for metric in test_mean_metrics.keys():
+            test_mean_metrics[metric] /= len(outputs)
+
+        for k, v in test_mean_metrics.items():
+            self.logger.experiment.log_metric(
+                run_id=self.logger.run_id,
+                key=k,
+                value=v,
+                step=self.current_epoch,
+            )
+
     def on_save_checkpoint(self, checkpoint):

         checkpoint["enhancer"] = {
@@ -202,9 +248,6 @@
         },
     }

-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
-        pass
-
     @classmethod
     def from_pretrained(
         cls,
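
Note on PATCH 1/2: PyTorch autograd saves intermediate tensors for the
backward pass, and an in-place update such as "x +=" bumps the saved
tensor's version counter, making backward() abort. A minimal sketch of
that failure mode, using exp() because its backward needs its own output
(not necessarily the exact operator that triggered this in Demucs):

    import torch

    x = torch.ones(3, requires_grad=True)
    y = torch.exp(x)    # autograd saves y, since d/dx exp(x) = exp(x)
    y += 1              # in-place edit bumps y's version counter
    y.sum().backward()  # RuntimeError: one of the variables needed for
                        #   gradient computation has been modified by an
                        #   inplace operation

"x = x + skip_connection[..., : x.shape[-1]]" allocates a fresh tensor
instead, leaving the saved activation intact. The second hunk is
behavior-preserving; it only binds the slice to a name before returning it.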
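
Note on PATCH 2/2: each *_epoch_end hook receives the list of dictionaries
returned by the corresponding *_step method, so the epoch-level value is
just the sum over that list divided by its length. A standalone sketch of
the test-metric aggregation (metric name and values are made up):

    import torch
    from collections import defaultdict

    # one dict per test step, as returned by test_step
    outputs = [{"stoi": torch.tensor(0.5)}, {"stoi": torch.tensor(0.75)}]

    mean_metrics = defaultdict(int)  # int() == 0 seeds each running sum
    for output in outputs:
        for metric, value in output.items():
            mean_metrics[metric] += value.item()
    for metric in mean_metrics:
        mean_metrics[metric] /= len(outputs)
    print(mean_metrics)  # defaultdict(<class 'int'>, {'stoi': 0.625})

The log_metric(run_id=..., key=..., value=..., step=...) call matches the
MlflowClient that Lightning's MLFlowLogger exposes as logger.experiment, so
these hooks assume an MLflow logger is attached. One asymmetry worth
flagging: test_epoch_end lacks the "if self.logger:" guard the other two
hooks have, so running trainer.test() with logging disabled would raise an
AttributeError there.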