merge dev

This commit is contained in:
shahules786 2022-10-11 16:50:09 +05:30
commit 031221b79e
2 changed files with 48 additions and 5 deletions

View File

@ -236,7 +236,7 @@ class Demucs(Model):
self.hparams.sampling_rate, self.hparams.sampling_rate,
) )
out = x[..., :length].clone() out = x[..., :length]
return out return out
def get_padding_length(self, input_length): def get_padding_length(self, input_length):

View File

@ -1,7 +1,8 @@
import os import os
from collections import defaultdict
from importlib import import_module from importlib import import_module
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Union from typing import List, Optional, Text, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
@ -192,6 +193,51 @@ class Model(pl.LightningModule):
return metric_dict return metric_dict
def training_epoch_end(self, outputs):
train_mean_loss = 0.0
for output in outputs:
train_mean_loss += output["loss"]
train_mean_loss /= len(outputs)
if self.logger:
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="train_loss_epoch",
value=train_mean_loss,
step=self.current_epoch,
)
def validation_epoch_end(self, outputs):
valid_mean_loss = 0.0
for output in outputs:
valid_mean_loss += output["loss"]
valid_mean_loss /= len(outputs)
if self.logger:
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="valid_loss_epoch",
value=valid_mean_loss,
step=self.current_epoch,
)
def test_epoch_end(self, outputs):
test_mean_metrics = defaultdict(int)
for output in outputs:
for metric, value in output.items():
test_mean_metrics[metric] += value.item()
for metric in test_mean_metrics.keys():
test_mean_metrics[metric] /= len(outputs)
for k, v in test_mean_metrics.items():
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key=k,
value=v,
step=self.current_epoch,
)
def on_save_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint):
checkpoint["enhancer"] = { checkpoint["enhancer"] = {
@ -202,9 +248,6 @@ class Model(pl.LightningModule):
}, },
} }
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
pass
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,