rename num_layers to depth

This commit is contained in:
shahules786 2022-09-27 12:54:50 +05:30
parent b55e12d15c
commit b742756311
1 changed files with 6 additions and 6 deletions

View File

@ -61,7 +61,7 @@ class WaveUnet(Model):
def __init__(
self,
num_channels:int=1,
num_layers:int=12,
depth:int=12,
initial_output_channels:int=24,
sampling_rate:int=16000,
lr:float=1e-3,
@ -74,25 +74,25 @@ class WaveUnet(Model):
sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss, metric=metric
)
self.save_hyperparameters("num_layers")
self.save_hyperparameters("depth")
self.encoders = nn.ModuleList()
self.decoders = nn.ModuleList()
out_channels = initial_output_channels
for layer in range(num_layers):
for layer in range(depth):
encoder = WavenetEncoder(num_channels,out_channels)
self.encoders.append(encoder)
num_channels = out_channels
out_channels += initial_output_channels
if layer == num_layers -1 :
decoder = WavenetDecoder(num_layers * initial_output_channels + num_channels,num_channels)
if layer == depth -1 :
decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels)
else:
decoder = WavenetDecoder(num_channels+out_channels,num_channels)
self.decoders.insert(0,decoder)
bottleneck_dim = num_layers * initial_output_channels
bottleneck_dim = depth * initial_output_channels
self.bottleneck = nn.Sequential(
nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1,
padding=7),