diff --git a/enhancer/utils/loss.py b/enhancer/utils/loss.py index c410914..a51669c 100644 --- a/enhancer/utils/loss.py +++ b/enhancer/utils/loss.py @@ -17,6 +17,7 @@ class mean_squared_error(nn.Module): class mean_absolute_error(nn.Module): def __init__(self,reduction="mean"): + super().__init__() self.loss_fun = nn.L1Loss(reduction=reduction) @@ -27,6 +28,7 @@ class mean_absolute_error(nn.Module): class Avergeloss(nn.Module): def __init__(self,losses): + super().__init__() self.valid_losses = nn.ModuleList() for loss in losses: