diff --git a/enhancer/models/complexnn/rnn.py b/enhancer/models/complexnn/rnn.py index 7d19425..847030b 100644 --- a/enhancer/models/complexnn/rnn.py +++ b/enhancer/models/complexnn/rnn.py @@ -42,6 +42,8 @@ class ComplexLSTM(nn.Module): self.imag_linear = nn.Linear( self.hidden_size * bidirectional, self.projection_size ) + else: + self.projection_size = None def forward(self, input):