add test step

parent 1aca956ed4
commit 2587d5b867
@@ -13,7 +13,7 @@ from torch.optim import Adam
 from enhancer.data.dataset import EnhancerDataset
 from enhancer.inference import Inference
-from enhancer.loss import Avergeloss
+from enhancer.loss import LOSS_MAP, LossWrapper
 from enhancer.version import __version__
 
 CACHE_DIR = ""
 
@@ -76,7 +76,7 @@ class Model(pl.LightningModule):
         if isinstance(loss, str):
             loss = [loss]
 
-        self._loss = Avergeloss(loss)
+        self._loss = LossWrapper(loss)
 
     @property
     def metric(self):
@@ -84,11 +84,21 @@ class Model(pl.LightningModule):
 
     @metric.setter
     def metric(self, metric):
+        self._metric = []
         if isinstance(metric, str):
             metric = [metric]
 
-        self._metric = Avergeloss(metric)
+        for func in metric:
+            if func in LOSS_MAP.keys():
+                if func in ("pesq", "stoi"):
+                    self._metric.append(
+                        LOSS_MAP[func](self.hparams.sampling_rate)
+                    )
+                else:
+                    self._metric.append(LOSS_MAP[func]())
+
+            else:
+                raise ValueError(f"Invalid metrics {func}")
 
     @property
     def dataset(self):
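
The rewritten metric setter above depends on a LOSS_MAP registry imported from enhancer.loss, which this commit does not show. Below is a minimal sketch of the shape that registry would need; the class names Stoi and MeanAbsoluteError are assumptions, not the repository's actual definitions. Note the name attribute, which the test_step added later in this diff reads, and the sampling-rate constructor for the perceptual metrics:

# Hypothetical sketch of the LOSS_MAP registry assumed by the metric setter.
# enhancer/loss.py is not part of this diff, so these names are illustrative.


class Stoi:
    name = "stoi"  # test_step uses metric.name as the logged key

    def __init__(self, sampling_rate: int):
        # pesq/stoi are the names the setter constructs with a sampling rate
        self.sampling_rate = sampling_rate

    def __call__(self, target, prediction):
        ...  # compute short-time objective intelligibility


class MeanAbsoluteError:
    name = "mae"

    def __call__(self, target, prediction):
        ...  # compute mean absolute error


# Keys are metric names; values are classes the setter instantiates.
LOSS_MAP = {"stoi": Stoi, "mae": MeanAbsoluteError}
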
@@ -109,6 +119,9 @@ class Model(pl.LightningModule):
     def val_dataloader(self):
         return self.dataset.val_dataloader()
 
+    def test_dataloader(self):
+        return self.dataset.test_dataloader()
+
     def configure_optimizers(self):
         return Adam(self.parameters(), lr=self.hparams.lr)
 
@@ -140,9 +153,7 @@ class Model(pl.LightningModule):
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
-        metric_val = self.metric(prediction, target)
         loss_val = self.loss(prediction, target)
-        self.log("val_metric", metric_val.item())
         self.log("val_loss", loss_val.item())
 
         if (
@@ -156,15 +167,28 @@ class Model(pl.LightningModule):
                 value=loss_val.item(),
                 step=self.global_step,
             )
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_metric",
-                value=metric_val.item(),
-                step=self.global_step,
-            )
 
         return {"loss": loss_val}
 
+    def test_step(self, batch, batch_idx):
+
+        metric_dict = {}
+        mixed_waveform = batch["noisy"]
+        target = batch["clean"]
+        prediction = self(mixed_waveform)
+
+        for metric in self.metric:
+            value = metric(target, prediction)
+            metric_dict[metric.name] = value
+
+        self.logger.experiment.log_metrics(
+            run_id=self.logger.run_id,
+            metrics=metric_dict,
+            step=self.global_step,
+        )
+
+        return metric_dict
+
     def on_save_checkpoint(self, checkpoint):
 
         checkpoint["enhancer"] = {
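
A minimal sketch of how the new test hook might be driven end to end, assuming an MLflow logger (test_step reads self.logger.run_id) and illustrative constructor arguments, since the Model and EnhancerDataset signatures are outside this diff:

import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

from enhancer.data.dataset import EnhancerDataset
from enhancer.models import Model  # assumed module path for the class above

# Hypothetical arguments; the real signatures are not shown in this commit.
dataset = EnhancerDataset(...)  # must expose test_dataloader() for the new hook
model = Model(dataset=dataset, loss="mae", metric=["stoi", "pesq"])

# test_step logs through self.logger.experiment with self.logger.run_id,
# so an MLflow-backed logger is assumed here.
logger = MLFlowLogger(experiment_name="enhancer")
trainer = pl.Trainer(max_epochs=1, logger=logger)
trainer.fit(model)
trainer.test(model)  # runs test_step over dataset.test_dataloader()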