diff --git a/enhancer/models/complexnn/utils.py b/enhancer/models/complexnn/utils.py index 1b6ff78..0c28f9b 100644 --- a/enhancer/models/complexnn/utils.py +++ b/enhancer/models/complexnn/utils.py @@ -125,6 +125,10 @@ class ComplexBatchNorm2D(nn.Module): self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor) self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor) self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor) + else: + batch_var_rr = self.Var_rr.view(vdim) + batch_var_ii = self.Var_ii.view(vdim) + batch_var_ri = self.Var_ri.view(vdim) batch_var_rr += self.eps batch_var_ii += self.eps