diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index 62384fc..89b4bb7 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -61,7 +61,7 @@ class WaveUnet(Model): def __init__( self, num_channels:int=1, - num_layers:int=12, + depth:int=12, initial_output_channels:int=24, sampling_rate:int=16000, lr:float=1e-3, @@ -74,25 +74,25 @@ class WaveUnet(Model): sampling_rate=sampling_rate,lr=lr, dataset=dataset,duration=duration,loss=loss, metric=metric ) - self.save_hyperparameters("num_layers") + self.save_hyperparameters("depth") self.encoders = nn.ModuleList() self.decoders = nn.ModuleList() out_channels = initial_output_channels - for layer in range(num_layers): + for layer in range(depth): encoder = WavenetEncoder(num_channels,out_channels) self.encoders.append(encoder) num_channels = out_channels out_channels += initial_output_channels - if layer == num_layers -1 : - decoder = WavenetDecoder(num_layers * initial_output_channels + num_channels,num_channels) + if layer == depth -1 : + decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels) else: decoder = WavenetDecoder(num_channels+out_channels,num_channels) self.decoders.insert(0,decoder) - bottleneck_dim = num_layers * initial_output_channels + bottleneck_dim = depth * initial_output_channels self.bottleneck = nn.Sequential( nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1, padding=7),