69 lines
1.7 KiB
Python
69 lines
1.7 KiB
Python
from typing import Optional, Union, List
|
|
from torch.optim import Adam
|
|
import pytorch_lightning as pl
|
|
|
|
from enhancer.data.dataset import Dataset
|
|
from enhancer.utils.loss import Avergeloss
|
|
|
|
|
|
class Model(pl.LightningModule):
    """Base LightningModule for mono-channel speech-enhancement models.

    Wires an `enhancer` Dataset to Lightning's dataloader hooks, builds an
    averaged loss criterion from the `loss` hyperparameter, and implements
    a generic training step (predict clean audio from noisy audio).

    Args:
        num_channels: Number of audio channels; only mono (1) is supported.
        sampling_rate: Audio sampling rate in Hz.
        lr: Learning rate for the Adam optimizer.
        dataset: Optional Dataset providing train/val dataloaders. May be
            None (e.g. for inference-only use); training hooks require it.
        loss: A single loss name (e.g. "mse") or a list of loss names to
            be averaged together.
    """

    def __init__(
        self,
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[Dataset] = None,
        loss: Union[str, List] = "mse",
    ):
        super().__init__()
        assert num_channels == 1, "Enhancer only support for mono channel models"
        # Persist scalar hyperparameters so checkpoints can restore them;
        # the dataset object itself is deliberately excluded.
        self.save_hyperparameters("num_channels", "sampling_rate", "lr", "loss")
        self.dataset = dataset

    @property
    def dataset(self):
        """Dataset backing the train/val dataloaders (may be None)."""
        return self._dataset

    @dataset.setter
    def dataset(self, dataset):
        self._dataset = dataset

    def setup(self, stage: Optional[str] = None):
        """Lightning hook: prepare the dataset and the loss before fitting.

        Raises:
            AttributeError: if called for the "fit" stage with no dataset set.
        """
        if stage == "fit":
            self.dataset.setup(stage)
            # Give the dataset a back-reference so it can query model
            # properties (e.g. sampling rate) when building batches.
            self.dataset.model = self
            self.setup_loss()

    def setup_loss(self):
        """Build `self.loss` from the `loss` hyperparameter.

        Accepts either a single loss name or a list of names. Fixes a bug
        where a list-valued `loss` left `losses` unbound (NameError).
        """
        loss = self.hparams.loss
        # Normalize to a list so Avergeloss always receives a sequence.
        losses = [loss] if isinstance(loss, str) else list(loss)
        self.loss = Avergeloss(losses)

    def train_dataloader(self):
        """Return the training dataloader from the attached dataset."""
        return self.dataset.train_dataloader()

    def val_dataloader(self):
        """Return the validation dataloader from the attached dataset."""
        return self.dataset.val_dataloader()

    def configure_optimizers(self):
        """Create an Adam optimizer over all model parameters."""
        return Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx: int):
        """Run one training step.

        Args:
            batch: Mapping with "noisy" (model input) and "clean" (target)
                waveform tensors.
            batch_idx: Index of the batch (unused, required by Lightning).

        Returns:
            dict with the scalar training "loss".
        """
        mixed_waveform = batch["noisy"]
        target = batch["clean"]
        prediction = self(mixed_waveform)

        loss = self.loss(prediction, target)

        return {"loss": loss}

    @classmethod
    def from_pretrained(cls):
        """Load a pretrained model. Not implemented yet."""
        pass