add test step

This commit is contained in:
shahules786 2022-10-10 12:47:24 +05:30
parent 1aca956ed4
commit 2587d5b867
1 changed file with 36 additions and 12 deletions

View File

@@ -13,7 +13,7 @@ from torch.optim import Adam
from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import Avergeloss
from enhancer.loss import LOSS_MAP, LossWrapper
from enhancer.version import __version__
CACHE_DIR = ""
@@ -76,7 +76,7 @@ class Model(pl.LightningModule):
if isinstance(loss, str):
loss = [loss]
self._loss = Avergeloss(loss)
self._loss = LossWrapper(loss)
@property
def metric(self):
@@ -84,11 +84,21 @@ class Model(pl.LightningModule):
@metric.setter
def metric(self, metric):
self._metric = []
if isinstance(metric, str):
metric = [metric]
self._metric = Avergeloss(metric)
for func in metric:
if func in LOSS_MAP.keys():
if func in ("pesq", "stoi"):
self._metric.append(
LOSS_MAP[func](self.hparams.sampling_rate)
)
else:
self._metric.append(LOSS_MAP[func]())
else:
raise ValueError(f"Invalid metrics {func}")
@property
def dataset(self):
@@ -109,6 +119,9 @@ class Model(pl.LightningModule):
def val_dataloader(self):
return self.dataset.val_dataloader()
def test_dataloader(self):
return self.dataset.test_dataloader()
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.lr)
@@ -140,9 +153,7 @@ class Model(pl.LightningModule):
target = batch["clean"]
prediction = self(mixed_waveform)
metric_val = self.metric(prediction, target)
loss_val = self.loss(prediction, target)
self.log("val_metric", metric_val.item())
self.log("val_loss", loss_val.item())
if (
@@ -156,14 +167,27 @@ class Model(pl.LightningModule):
value=loss_val.item(),
step=self.global_step,
)
self.logger.experiment.log_metric(
return {"loss": loss_val}
def test_step(self, batch, batch_idx):
metric_dict = {}
mixed_waveform = batch["noisy"]
target = batch["clean"]
prediction = self(mixed_waveform)
for metric in self.metric:
value = metric(target, prediction)
metric_dict[metric.name] = value
self.logger.experiment.log_metrics(
run_id=self.logger.run_id,
key="val_metric",
value=metric_val.item(),
metrics=metric_dict,
step=self.global_step,
)
return {"loss": loss_val}
return metric_dict
def on_save_checkpoint(self, checkpoint):