set training step
This commit is contained in:
parent
f8c8884ce9
commit
619d0be8ce
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Union, List
|
||||
from torch.optim import Adam
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from enhancer.data.dataset import Dataset
|
||||
from enhancer.utils.loss import LOSS_MAP, Avergeloss
|
||||
|
||||
|
||||
class Model(pl.LightningModule):
|
||||
|
|
@ -13,10 +14,11 @@ class Model(pl.LightningModule):
|
|||
sampling_rate:int=16000,
|
||||
lr:float=1e-3,
|
||||
dataset:Optional[Dataset]=None,
|
||||
loss: Union[str, List] = "mse"
|
||||
):
|
||||
super().__init__()
|
||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||
self.save_hyperparameters("num_channels","sampling_rate","lr")
|
||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss")
|
||||
self.dataset = dataset
|
||||
|
||||
|
||||
|
|
@ -32,7 +34,15 @@ class Model(pl.LightningModule):
|
|||
if stage == "fit":
|
||||
self.dataset.setup(stage)
|
||||
self.dataset.model = self
|
||||
self.setup_loss()
|
||||
|
||||
def setup_loss(self):
|
||||
|
||||
loss = self.hparams.loss
|
||||
if isinstance(loss,str):
|
||||
losses = [loss]
|
||||
|
||||
self.loss = Avergeloss(losses)
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.dataset.train_dataloader()
|
||||
|
|
@ -44,7 +54,15 @@ class Model(pl.LightningModule):
|
|||
return Adam(self.parameters, lr = self.hparams.lr)
|
||||
|
||||
def training_step(self,batch, batch_idx:int):
|
||||
pass
|
||||
|
||||
mixed_waveform = batch["noisy"]
|
||||
target = batch["clean"]
|
||||
prediction = self(mixed_waveform)
|
||||
|
||||
loss = self.loss(prediction, target)
|
||||
|
||||
return {"loss":loss}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,):
|
||||
|
|
|
|||
Loading…
Reference in New Issue