fix architecture
This commit is contained in:
parent
c193a48e8e
commit
5eb15b41c4
|
|
@ -49,7 +49,7 @@ class DemucsEncoder(nn.Module):
|
|||
self.encoder = nn.Sequential(
|
||||
nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
|
||||
nn.ReLU(),
|
||||
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
|
||||
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
|
||||
activation,
|
||||
)
|
||||
|
||||
|
|
@ -72,7 +72,7 @@ class DemucsDecoder(nn.Module):
|
|||
activation = nn.GLU(1) if glu else nn.ReLU()
|
||||
multi_factor = 2 if glu else 1
|
||||
self.decoder = nn.Sequential(
|
||||
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
|
||||
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
|
||||
activation,
|
||||
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
|
||||
)
|
||||
|
|
@ -116,7 +116,7 @@ class Demucs(Model):
|
|||
ED_DEFAULTS = {
|
||||
"initial_output_channels": 48,
|
||||
"kernel_size": 8,
|
||||
"stride": 1,
|
||||
"stride": 4,
|
||||
"depth": 5,
|
||||
"glu": True,
|
||||
"growth_factor": 2,
|
||||
|
|
@ -179,7 +179,7 @@ class Demucs(Model):
|
|||
num_channels=num_channels,
|
||||
hidden_size=hidden,
|
||||
kernel_size=encoder_decoder["kernel_size"],
|
||||
stride=1,
|
||||
stride=encoder_decoder["stride"],
|
||||
glu=encoder_decoder["glu"],
|
||||
layer=layer,
|
||||
)
|
||||
|
|
@ -236,7 +236,7 @@ class Demucs(Model):
|
|||
self.hparams.sampling_rate,
|
||||
)
|
||||
|
||||
return x
|
||||
return x[..., :length]
|
||||
|
||||
def get_padding_length(self, input_length):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue