fix architecture

This commit is contained in:
shahules786 2022-10-11 11:11:16 +05:30
parent c193a48e8e
commit 5eb15b41c4
1 changed files with 5 additions and 5 deletions

View File

@ -49,7 +49,7 @@ class DemucsEncoder(nn.Module):
self.encoder = nn.Sequential( self.encoder = nn.Sequential(
nn.Conv1d(num_channels, hidden_size, kernel_size, stride), nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
nn.ReLU(), nn.ReLU(),
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
activation, activation,
) )
@ -72,7 +72,7 @@ class DemucsDecoder(nn.Module):
activation = nn.GLU(1) if glu else nn.ReLU() activation = nn.GLU(1) if glu else nn.ReLU()
multi_factor = 2 if glu else 1 multi_factor = 2 if glu else 1
self.decoder = nn.Sequential( self.decoder = nn.Sequential(
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
activation, activation,
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride), nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
) )
@ -116,7 +116,7 @@ class Demucs(Model):
ED_DEFAULTS = { ED_DEFAULTS = {
"initial_output_channels": 48, "initial_output_channels": 48,
"kernel_size": 8, "kernel_size": 8,
"stride": 1, "stride": 4,
"depth": 5, "depth": 5,
"glu": True, "glu": True,
"growth_factor": 2, "growth_factor": 2,
@ -179,7 +179,7 @@ class Demucs(Model):
num_channels=num_channels, num_channels=num_channels,
hidden_size=hidden, hidden_size=hidden,
kernel_size=encoder_decoder["kernel_size"], kernel_size=encoder_decoder["kernel_size"],
stride=1, stride=encoder_decoder["stride"],
glu=encoder_decoder["glu"], glu=encoder_decoder["glu"],
layer=layer, layer=layer,
) )
@ -236,7 +236,7 @@ class Demucs(Model):
self.hparams.sampling_rate, self.hparams.sampling_rate,
) )
return x return x[..., :length]
def get_padding_length(self, input_length): def get_padding_length(self, input_length):