diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index 4a0e20f..62384fc 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -86,7 +86,7 @@ class WaveUnet(Model): num_channels = out_channels out_channels += initial_output_channels if layer == num_layers -1 : - decoder = WavenetDecoder(num_layers * initial_output_channels + num_channels,inp_channels) + decoder = WavenetDecoder(num_layers * initial_output_channels + num_channels,num_channels) else: decoder = WavenetDecoder(num_channels+out_channels,num_channels)