merge dev

This commit is contained in:
shahules786 2022-10-11 16:50:09 +05:30
commit 031221b79e
2 changed files with 48 additions and 5 deletions

View File

@ -236,7 +236,7 @@ class Demucs(Model):
self.hparams.sampling_rate,
)
out = x[..., :length].clone()
out = x[..., :length]
return out
def get_padding_length(self, input_length):

View File

@ -1,7 +1,8 @@
import os
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, List, Optional, Text, Union
from typing import List, Optional, Text, Union
from urllib.parse import urlparse
import numpy as np
@ -192,6 +193,51 @@ class Model(pl.LightningModule):
return metric_dict
def training_epoch_end(self, outputs):
train_mean_loss = 0.0
for output in outputs:
train_mean_loss += output["loss"]
train_mean_loss /= len(outputs)
if self.logger:
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="train_loss_epoch",
value=train_mean_loss,
step=self.current_epoch,
)
def validation_epoch_end(self, outputs):
valid_mean_loss = 0.0
for output in outputs:
valid_mean_loss += output["loss"]
valid_mean_loss /= len(outputs)
if self.logger:
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="valid_loss_epoch",
value=valid_mean_loss,
step=self.current_epoch,
)
def test_epoch_end(self, outputs):
test_mean_metrics = defaultdict(int)
for output in outputs:
for metric, value in output.items():
test_mean_metrics[metric] += value.item()
for metric in test_mean_metrics.keys():
test_mean_metrics[metric] /= len(outputs)
for k, v in test_mean_metrics.items():
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key=k,
value=v,
step=self.current_epoch,
)
def on_save_checkpoint(self, checkpoint):
checkpoint["enhancer"] = {
@ -202,9 +248,6 @@ class Model(pl.LightningModule):
},
}
def on_load_checkpoint(self, checkpoint: Dict[str, Any]):
pass
@classmethod
def from_pretrained(
cls,