fix architecture
This commit is contained in:
parent
c193a48e8e
commit
5eb15b41c4
|
|
@ -49,7 +49,7 @@ class DemucsEncoder(nn.Module):
|
||||||
self.encoder = nn.Sequential(
|
self.encoder = nn.Sequential(
|
||||||
nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
|
nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
|
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
|
||||||
activation,
|
activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -72,7 +72,7 @@ class DemucsDecoder(nn.Module):
|
||||||
activation = nn.GLU(1) if glu else nn.ReLU()
|
activation = nn.GLU(1) if glu else nn.ReLU()
|
||||||
multi_factor = 2 if glu else 1
|
multi_factor = 2 if glu else 1
|
||||||
self.decoder = nn.Sequential(
|
self.decoder = nn.Sequential(
|
||||||
nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
|
nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
|
||||||
activation,
|
activation,
|
||||||
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
|
nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
|
||||||
)
|
)
|
||||||
|
|
@ -116,7 +116,7 @@ class Demucs(Model):
|
||||||
ED_DEFAULTS = {
|
ED_DEFAULTS = {
|
||||||
"initial_output_channels": 48,
|
"initial_output_channels": 48,
|
||||||
"kernel_size": 8,
|
"kernel_size": 8,
|
||||||
"stride": 1,
|
"stride": 4,
|
||||||
"depth": 5,
|
"depth": 5,
|
||||||
"glu": True,
|
"glu": True,
|
||||||
"growth_factor": 2,
|
"growth_factor": 2,
|
||||||
|
|
@ -179,7 +179,7 @@ class Demucs(Model):
|
||||||
num_channels=num_channels,
|
num_channels=num_channels,
|
||||||
hidden_size=hidden,
|
hidden_size=hidden,
|
||||||
kernel_size=encoder_decoder["kernel_size"],
|
kernel_size=encoder_decoder["kernel_size"],
|
||||||
stride=1,
|
stride=encoder_decoder["stride"],
|
||||||
glu=encoder_decoder["glu"],
|
glu=encoder_decoder["glu"],
|
||||||
layer=layer,
|
layer=layer,
|
||||||
)
|
)
|
||||||
|
|
@ -236,7 +236,7 @@ class Demucs(Model):
|
||||||
self.hparams.sampling_rate,
|
self.hparams.sampling_rate,
|
||||||
)
|
)
|
||||||
|
|
||||||
return x
|
return x[..., :length]
|
||||||
|
|
||||||
def get_padding_length(self, input_length):
|
def get_padding_length(self, input_length):
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue