diff --git a/enhancer/models/complexnn/utils.py b/enhancer/models/complexnn/utils.py index d5de558..1b6ff78 100644 --- a/enhancer/models/complexnn/utils.py +++ b/enhancer/models/complexnn/utils.py @@ -7,7 +7,7 @@ class ComplexBatchNorm2D(nn.Module): self, num_features: int, eps: float = 1e-5, - momentum: bool = True, + momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, ): @@ -25,12 +25,11 @@ class ComplexBatchNorm2D(nn.Module): self.eps = eps if self.affine: - values = torch.Tensor(self.num_features) - self.Wrr = nn.parameter.Parameter(values) - self.Wri = nn.parameter.Parameter(values) - self.Wii = nn.parameter.Parameter(values) - self.Br = nn.parameter.Parameter(values) - self.Bi = nn.parameter.Parameter(values) + self.Wrr = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Wri = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Wii = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Br = nn.parameter.Parameter(torch.Tensor(self.num_features)) + self.Bi = nn.parameter.Parameter(torch.Tensor(self.num_features)) else: self.register_parameter("Wrr", None) self.register_parameter("Wri", None) @@ -39,7 +38,7 @@ class ComplexBatchNorm2D(nn.Module): self.register_parameter("Bi", None) if self.track_running_stats: - values = torch.Tensor(self.num_features) + values = torch.zeros(self.num_features) self.register_buffer("Mean_real", values) self.register_buffer("Mean_imag", values) self.register_buffer("Var_rr", values) @@ -111,8 +110,8 @@ class ComplexBatchNorm2D(nn.Module): batch_mean_real = self.Mean_real.view(vdim) batch_mean_imag = self.Mean_imag.view(vdim) - real -= batch_mean_real - imag -= batch_mean_imag + real = real - batch_mean_real + imag = imag - batch_mean_imag if training: batch_var_rr = real * real @@ -141,7 +140,7 @@ class ComplexBatchNorm2D(nn.Module): s = batch_var_rr * batch_var_ii - batch_var_ri * batch_var_ri t = (tau + 2 * s).sqrt() - rst = 1 / (s * t) + rst = (s * t).reciprocal() Urr = (batch_var_ii + s) * rst Uri = -batch_var_ri * rst Uii = (batch_var_rr + s) * rst @@ -162,6 +161,10 @@ class ComplexBatchNorm2D(nn.Module): yr = (Zrr * real) + (Zri * imag) yi = (Zir * real) + (Zii * imag) + if self.affine: + yr = yr + self.Br.view(vdim) + yi = yi + self.Bi.view(vdim) + outputs = torch.cat([yr, yi], 1) return outputs @@ -178,3 +181,15 @@ class ComplexRelu(nn.Module): real = self.real_relu(real) imag = self.imag_relu(imag) return torch.cat([real, imag], dim=1) + + +def complex_cat(inputs, axis=1): + + real, imag = [], [] + for data in inputs: + real_data, imag_data = torch.chunk(data, 2, axis) + real.append(real_data) + imag.append(imag_data) + real = torch.cat(real, axis) + imag = torch.cat(imag, axis) + return torch.cat([real, imag], axis)