add test step
This commit is contained in:
parent
1aca956ed4
commit
2587d5b867
|
|
@ -13,7 +13,7 @@ from torch.optim import Adam
|
|||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.inference import Inference
|
||||
from enhancer.loss import Avergeloss
|
||||
from enhancer.loss import LOSS_MAP, LossWrapper
|
||||
from enhancer.version import __version__
|
||||
|
||||
CACHE_DIR = ""
|
||||
|
|
@ -76,7 +76,7 @@ class Model(pl.LightningModule):
|
|||
if isinstance(loss, str):
|
||||
loss = [loss]
|
||||
|
||||
self._loss = Avergeloss(loss)
|
||||
self._loss = LossWrapper(loss)
|
||||
|
||||
@property
|
||||
def metric(self):
|
||||
|
|
@ -84,11 +84,21 @@ class Model(pl.LightningModule):
|
|||
|
||||
@metric.setter
|
||||
def metric(self, metric):
|
||||
|
||||
self._metric = []
|
||||
if isinstance(metric, str):
|
||||
metric = [metric]
|
||||
|
||||
self._metric = Avergeloss(metric)
|
||||
for func in metric:
|
||||
if func in LOSS_MAP.keys():
|
||||
if func in ("pesq", "stoi"):
|
||||
self._metric.append(
|
||||
LOSS_MAP[func](self.hparams.sampling_rate)
|
||||
)
|
||||
else:
|
||||
self._metric.append(LOSS_MAP[func]())
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid metrics {func}")
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
|
|
@ -109,6 +119,9 @@ class Model(pl.LightningModule):
|
|||
def val_dataloader(self):
|
||||
return self.dataset.val_dataloader()
|
||||
|
||||
def test_dataloader(self):
|
||||
return self.dataset.test_dataloader()
|
||||
|
||||
def configure_optimizers(self):
|
||||
return Adam(self.parameters(), lr=self.hparams.lr)
|
||||
|
||||
|
|
@ -140,9 +153,7 @@ class Model(pl.LightningModule):
|
|||
target = batch["clean"]
|
||||
prediction = self(mixed_waveform)
|
||||
|
||||
metric_val = self.metric(prediction, target)
|
||||
loss_val = self.loss(prediction, target)
|
||||
self.log("val_metric", metric_val.item())
|
||||
self.log("val_loss", loss_val.item())
|
||||
|
||||
if (
|
||||
|
|
@ -156,15 +167,28 @@ class Model(pl.LightningModule):
|
|||
value=loss_val.item(),
|
||||
step=self.global_step,
|
||||
)
|
||||
self.logger.experiment.log_metric(
|
||||
run_id=self.logger.run_id,
|
||||
key="val_metric",
|
||||
value=metric_val.item(),
|
||||
step=self.global_step,
|
||||
)
|
||||
|
||||
return {"loss": loss_val}
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
|
||||
metric_dict = {}
|
||||
mixed_waveform = batch["noisy"]
|
||||
target = batch["clean"]
|
||||
prediction = self(mixed_waveform)
|
||||
|
||||
for metric in self.metric:
|
||||
value = metric(target, prediction)
|
||||
metric_dict[metric.name] = value
|
||||
|
||||
self.logger.experiment.log_metrics(
|
||||
run_id=self.logger.run_id,
|
||||
metrics=metric_dict,
|
||||
step=self.global_step,
|
||||
)
|
||||
|
||||
return metric_dict
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
|
||||
checkpoint["enhancer"] = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue