validation step
This commit is contained in:
parent
22eb9256e2
commit
a1103310f2
|
|
@ -14,11 +14,12 @@ class Model(pl.LightningModule):
|
||||||
sampling_rate:int=16000,
|
sampling_rate:int=16000,
|
||||||
lr:float=1e-3,
|
lr:float=1e-3,
|
||||||
dataset:Optional[Dataset]=None,
|
dataset:Optional[Dataset]=None,
|
||||||
loss: Union[str, List] = "mse"
|
loss: Union[str, List] = "mse",
|
||||||
|
metric:Union[str,List] = "mse"
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
assert num_channels ==1 , "Enhancer only support for mono channel models"
|
||||||
self.save_hyperparameters("num_channels","sampling_rate","lr","loss")
|
self.save_hyperparameters("num_channels","sampling_rate","lr","loss","metric")
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -34,15 +35,15 @@ class Model(pl.LightningModule):
|
||||||
if stage == "fit":
|
if stage == "fit":
|
||||||
self.dataset.setup(stage)
|
self.dataset.setup(stage)
|
||||||
self.dataset.model = self
|
self.dataset.model = self
|
||||||
self.setup_loss()
|
self.loss = self.setup_loss(self.hparams.loss)
|
||||||
|
self.metric = self.setup_loss(self.hparams.metric)
|
||||||
|
|
||||||
def setup_loss(self):
|
def setup_loss(self,loss):
|
||||||
|
|
||||||
loss = self.hparams.loss
|
|
||||||
if isinstance(loss,str):
|
if isinstance(loss,str):
|
||||||
losses = [loss]
|
losses = [loss]
|
||||||
|
|
||||||
self.loss = Avergeloss(losses)
|
return Avergeloss(losses)
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return self.dataset.train_dataloader()
|
return self.dataset.train_dataloader()
|
||||||
|
|
@ -63,6 +64,15 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
return {"loss":loss}
|
return {"loss":loss}
|
||||||
|
|
||||||
|
def validation_step(self,batch,batch_idx:int):
|
||||||
|
|
||||||
|
mixed_waveform = batch["noisy"]
|
||||||
|
target = batch["clean"]
|
||||||
|
prediction = self(mixed_waveform)
|
||||||
|
|
||||||
|
loss = self.metric(prediction, target)
|
||||||
|
|
||||||
|
return {"loss":loss}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls,):
|
def from_pretrained(cls,):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue