add test step
This commit is contained in:
		
							parent
							
								
									1aca956ed4
								
							
						
					
					
						commit
						2587d5b867
					
				|  | @ -13,7 +13,7 @@ from torch.optim import Adam | |||
| 
 | ||||
| from enhancer.data.dataset import EnhancerDataset | ||||
| from enhancer.inference import Inference | ||||
| from enhancer.loss import Avergeloss | ||||
| from enhancer.loss import LOSS_MAP, LossWrapper | ||||
| from enhancer.version import __version__ | ||||
| 
 | ||||
| CACHE_DIR = "" | ||||
|  | @ -76,7 +76,7 @@ class Model(pl.LightningModule): | |||
|         if isinstance(loss, str): | ||||
|             loss = [loss] | ||||
| 
 | ||||
|         self._loss = Avergeloss(loss) | ||||
|         self._loss = LossWrapper(loss) | ||||
| 
 | ||||
|     @property | ||||
|     def metric(self): | ||||
|  | @ -84,11 +84,21 @@ class Model(pl.LightningModule): | |||
| 
 | ||||
|     @metric.setter | ||||
|     def metric(self, metric): | ||||
| 
 | ||||
|         self._metric = [] | ||||
|         if isinstance(metric, str): | ||||
|             metric = [metric] | ||||
| 
 | ||||
|         self._metric = Avergeloss(metric) | ||||
|         for func in metric: | ||||
|             if func in LOSS_MAP.keys(): | ||||
|                 if func in ("pesq", "stoi"): | ||||
|                     self._metric.append( | ||||
|                         LOSS_MAP[func](self.hparams.sampling_rate) | ||||
|                     ) | ||||
|                 else: | ||||
|                     self._metric.append(LOSS_MAP[func]()) | ||||
| 
 | ||||
|             else: | ||||
|                 raise ValueError(f"Invalid metrics {func}") | ||||
| 
 | ||||
|     @property | ||||
|     def dataset(self): | ||||
|  | @ -109,6 +119,9 @@ class Model(pl.LightningModule): | |||
|     def val_dataloader(self): | ||||
|         return self.dataset.val_dataloader() | ||||
| 
 | ||||
|     def test_dataloader(self): | ||||
|         return self.dataset.test_dataloader() | ||||
| 
 | ||||
|     def configure_optimizers(self): | ||||
|         return Adam(self.parameters(), lr=self.hparams.lr) | ||||
| 
 | ||||
|  | @ -140,9 +153,7 @@ class Model(pl.LightningModule): | |||
|         target = batch["clean"] | ||||
|         prediction = self(mixed_waveform) | ||||
| 
 | ||||
|         metric_val = self.metric(prediction, target) | ||||
|         loss_val = self.loss(prediction, target) | ||||
|         self.log("val_metric", metric_val.item()) | ||||
|         self.log("val_loss", loss_val.item()) | ||||
| 
 | ||||
|         if ( | ||||
|  | @ -156,14 +167,27 @@ class Model(pl.LightningModule): | |||
|                 value=loss_val.item(), | ||||
|                 step=self.global_step, | ||||
|             ) | ||||
|             self.logger.experiment.log_metric( | ||||
| 
 | ||||
|         return {"loss": loss_val} | ||||
| 
 | ||||
|     def test_step(self, batch, batch_idx): | ||||
| 
 | ||||
|         metric_dict = {} | ||||
|         mixed_waveform = batch["noisy"] | ||||
|         target = batch["clean"] | ||||
|         prediction = self(mixed_waveform) | ||||
| 
 | ||||
|         for metric in self.metric: | ||||
|             value = metric(target, prediction) | ||||
|             metric_dict[metric.name] = value | ||||
| 
 | ||||
|         self.logger.experiment.log_metrics( | ||||
|             run_id=self.logger.run_id, | ||||
|                 key="val_metric", | ||||
|                 value=metric_val.item(), | ||||
|             metrics=metric_dict, | ||||
|             step=self.global_step, | ||||
|         ) | ||||
| 
 | ||||
|         return {"loss": loss_val} | ||||
|         return metric_dict | ||||
| 
 | ||||
|     def on_save_checkpoint(self, checkpoint): | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786