79 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			79 lines
		
	
	
		
			2.0 KiB
		
	
	
	
		
			Python
		
	
	
	
| from typing import Optional, Union, List
 | |
| from torch.optim import Adam
 | |
| import pytorch_lightning as pl
 | |
| 
 | |
| from enhancer.data.dataset import Dataset
 | |
| from enhancer.utils.loss import Avergeloss
 | |
| 
 | |
| 
 | |
| class Model(pl.LightningModule):
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         num_channels:int=1,
 | |
|         sampling_rate:int=16000,
 | |
|         lr:float=1e-3,
 | |
|         dataset:Optional[Dataset]=None,
 | |
|         loss: Union[str, List] = "mse",
 | |
|         metric:Union[str,List] = "mse"
 | |
|     ):
 | |
|         super().__init__()
 | |
|         assert num_channels ==1 , "Enhancer only support for mono channel models"
 | |
|         self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric")
 | |
|         self.dataset = dataset
 | |
|         
 | |
|     
 | |
|     @property
 | |
|     def dataset(self):
 | |
|         return self._dataset
 | |
| 
 | |
|     @dataset.setter
 | |
|     def dataset(self,dataset):
 | |
|         self._dataset = dataset
 | |
| 
 | |
|     def setup(self,stage:Optional[str]=None):
 | |
|         if stage == "fit":
 | |
|             self.dataset.setup(stage)
 | |
|             self.dataset.model = self
 | |
|             self.loss = self.setup_loss(self.hparams.loss) 
 | |
|             self.metric = self.setup_loss(self.hparams.metric)
 | |
|         
 | |
|     def setup_loss(self,loss):
 | |
| 
 | |
|         if isinstance(loss,str):
 | |
|             losses = [loss]
 | |
|         
 | |
|         return Avergeloss(losses)
 | |
| 
 | |
|     def train_dataloader(self):
 | |
|         return self.dataset.train_dataloader()
 | |
| 
 | |
|     def val_dataloader(self):
 | |
|         return self.dataset.val_dataloader()
 | |
| 
 | |
|     def configure_optimizers(self):
 | |
|         return Adam(self.parameters(), lr = self.hparams.lr)
 | |
| 
 | |
|     def training_step(self,batch, batch_idx:int):
 | |
| 
 | |
|         mixed_waveform = batch["noisy"]
 | |
|         target = batch["clean"]
 | |
|         prediction = self(mixed_waveform)
 | |
| 
 | |
|         loss = self.loss(prediction, target)
 | |
| 
 | |
|         return {"loss":loss}
 | |
| 
 | |
|     def validation_step(self,batch,batch_idx:int):
 | |
| 
 | |
|         mixed_waveform = batch["noisy"]
 | |
|         target = batch["clean"]
 | |
|         prediction = self(mixed_waveform)
 | |
| 
 | |
|         loss = self.metric(prediction, target)
 | |
| 
 | |
|         return {"loss":loss}
 | |
| 
 | |
|     @classmethod
 | |
|     def from_pretrained(cls,):
 | |
|         pass |