fix batchnorm eval() mode

This commit is contained in:
shahules786 2022-11-07 10:52:11 +05:30
parent 511d2141d4
commit 15c1d1ad94
1 changed files with 4 additions and 0 deletions

View File

@ -125,6 +125,10 @@ class ComplexBatchNorm2D(nn.Module):
self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor) self.Var_rr.lerp_(batch_var_rr.squeeze(), exp_avg_factor)
self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor) self.Var_ri.lerp_(batch_var_ri.squeeze(), exp_avg_factor)
self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor) self.Var_ii.lerp_(batch_var_ii.squeeze(), exp_avg_factor)
else:
batch_var_rr = self.Var_rr.view(vdim)
batch_var_ii = self.Var_ii.view(vdim)
batch_var_ri = self.Var_ri.view(vdim)
batch_var_rr += self.eps batch_var_rr += self.eps
batch_var_ii += self.eps batch_var_ii += self.eps