diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index bf9d429..5d7e99f 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -49,7 +49,7 @@ class DemucsEncoder(nn.Module): self.encoder = nn.Sequential( nn.Conv1d(num_channels, hidden_size, kernel_size, stride), nn.ReLU(), - nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), + nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), activation, ) @@ -72,7 +72,7 @@ class DemucsDecoder(nn.Module): activation = nn.GLU(1) if glu else nn.ReLU() multi_factor = 2 if glu else 1 self.decoder = nn.Sequential( - nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1), + nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1), activation, nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride), ) @@ -116,7 +116,7 @@ class Demucs(Model): ED_DEFAULTS = { "initial_output_channels": 48, "kernel_size": 8, - "stride": 1, + "stride": 4, "depth": 5, "glu": True, "growth_factor": 2, @@ -179,7 +179,7 @@ class Demucs(Model): num_channels=num_channels, hidden_size=hidden, kernel_size=encoder_decoder["kernel_size"], - stride=1, + stride=encoder_decoder["stride"], glu=encoder_decoder["glu"], layer=layer, ) @@ -236,7 +236,7 @@ class Demucs(Model): self.hparams.sampling_rate, ) - return x + return x[..., :length] def get_padding_length(self, input_length):