diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 980c583..64bf201 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -38,9 +38,6 @@ class Model(pl.LightningModule): super().__init__() assert num_channels ==1 , "Enhancer only support for mono channel models" self.dataset = dataset - if self.dataset is not None: - sampling_rate = self.dataset.sampling_rate - logging.warn("Setting model sampling rate same as dataset sampling rate") self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric","duration") if self.logger: self.logger.experiment.log_dict(dict(self.hparams),"hyperparameters.json") @@ -86,7 +83,7 @@ class Model(pl.LightningModule): loss = self.loss(prediction, target) if self.logger: - self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step) + self.logger.experiment.log_metric("train_loss",loss.item(), step=self.global_step) return {"loss":loss} @@ -98,7 +95,7 @@ class Model(pl.LightningModule): loss = self.metric(prediction, target) if self.logger: - self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step) + self.logger.experiment.log_metric("val_loss",loss.item(), step=self.global_step) return {"loss":loss}