add test step
This commit is contained in:
		
							parent
							
								
									1aca956ed4
								
							
						
					
					
						commit
						2587d5b867
					
				|  | @ -13,7 +13,7 @@ from torch.optim import Adam | ||||||
| 
 | 
 | ||||||
| from enhancer.data.dataset import EnhancerDataset | from enhancer.data.dataset import EnhancerDataset | ||||||
| from enhancer.inference import Inference | from enhancer.inference import Inference | ||||||
| from enhancer.loss import Avergeloss | from enhancer.loss import LOSS_MAP, LossWrapper | ||||||
| from enhancer.version import __version__ | from enhancer.version import __version__ | ||||||
| 
 | 
 | ||||||
| CACHE_DIR = "" | CACHE_DIR = "" | ||||||
|  | @ -76,7 +76,7 @@ class Model(pl.LightningModule): | ||||||
|         if isinstance(loss, str): |         if isinstance(loss, str): | ||||||
|             loss = [loss] |             loss = [loss] | ||||||
| 
 | 
 | ||||||
|         self._loss = Avergeloss(loss) |         self._loss = LossWrapper(loss) | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def metric(self): |     def metric(self): | ||||||
|  | @ -84,11 +84,21 @@ class Model(pl.LightningModule): | ||||||
| 
 | 
 | ||||||
|     @metric.setter |     @metric.setter | ||||||
|     def metric(self, metric): |     def metric(self, metric): | ||||||
| 
 |         self._metric = [] | ||||||
|         if isinstance(metric, str): |         if isinstance(metric, str): | ||||||
|             metric = [metric] |             metric = [metric] | ||||||
| 
 | 
 | ||||||
|         self._metric = Avergeloss(metric) |         for func in metric: | ||||||
|  |             if func in LOSS_MAP.keys(): | ||||||
|  |                 if func in ("pesq", "stoi"): | ||||||
|  |                     self._metric.append( | ||||||
|  |                         LOSS_MAP[func](self.hparams.sampling_rate) | ||||||
|  |                     ) | ||||||
|  |                 else: | ||||||
|  |                     self._metric.append(LOSS_MAP[func]()) | ||||||
|  | 
 | ||||||
|  |             else: | ||||||
|  |                 raise ValueError(f"Invalid metrics {func}") | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def dataset(self): |     def dataset(self): | ||||||
|  | @ -109,6 +119,9 @@ class Model(pl.LightningModule): | ||||||
|     def val_dataloader(self): |     def val_dataloader(self): | ||||||
|         return self.dataset.val_dataloader() |         return self.dataset.val_dataloader() | ||||||
| 
 | 
 | ||||||
|  |     def test_dataloader(self): | ||||||
|  |         return self.dataset.test_dataloader() | ||||||
|  | 
 | ||||||
|     def configure_optimizers(self): |     def configure_optimizers(self): | ||||||
|         return Adam(self.parameters(), lr=self.hparams.lr) |         return Adam(self.parameters(), lr=self.hparams.lr) | ||||||
| 
 | 
 | ||||||
|  | @ -140,9 +153,7 @@ class Model(pl.LightningModule): | ||||||
|         target = batch["clean"] |         target = batch["clean"] | ||||||
|         prediction = self(mixed_waveform) |         prediction = self(mixed_waveform) | ||||||
| 
 | 
 | ||||||
|         metric_val = self.metric(prediction, target) |  | ||||||
|         loss_val = self.loss(prediction, target) |         loss_val = self.loss(prediction, target) | ||||||
|         self.log("val_metric", metric_val.item()) |  | ||||||
|         self.log("val_loss", loss_val.item()) |         self.log("val_loss", loss_val.item()) | ||||||
| 
 | 
 | ||||||
|         if ( |         if ( | ||||||
|  | @ -156,15 +167,28 @@ class Model(pl.LightningModule): | ||||||
|                 value=loss_val.item(), |                 value=loss_val.item(), | ||||||
|                 step=self.global_step, |                 step=self.global_step, | ||||||
|             ) |             ) | ||||||
|             self.logger.experiment.log_metric( |  | ||||||
|                 run_id=self.logger.run_id, |  | ||||||
|                 key="val_metric", |  | ||||||
|                 value=metric_val.item(), |  | ||||||
|                 step=self.global_step, |  | ||||||
|             ) |  | ||||||
| 
 | 
 | ||||||
|         return {"loss": loss_val} |         return {"loss": loss_val} | ||||||
| 
 | 
 | ||||||
|  |     def test_step(self, batch, batch_idx): | ||||||
|  | 
 | ||||||
|  |         metric_dict = {} | ||||||
|  |         mixed_waveform = batch["noisy"] | ||||||
|  |         target = batch["clean"] | ||||||
|  |         prediction = self(mixed_waveform) | ||||||
|  | 
 | ||||||
|  |         for metric in self.metric: | ||||||
|  |             value = metric(target, prediction) | ||||||
|  |             metric_dict[metric.name] = value | ||||||
|  | 
 | ||||||
|  |         self.logger.experiment.log_metrics( | ||||||
|  |             run_id=self.logger.run_id, | ||||||
|  |             metrics=metric_dict, | ||||||
|  |             step=self.global_step, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         return metric_dict | ||||||
|  | 
 | ||||||
|     def on_save_checkpoint(self, checkpoint): |     def on_save_checkpoint(self, checkpoint): | ||||||
| 
 | 
 | ||||||
|         checkpoint["enhancer"] = { |         checkpoint["enhancer"] = { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786