return hidden & cell states
This commit is contained in:
parent
fb6d0afee0
commit
4edc90deb0
|
|
@ -23,7 +23,7 @@ class DeLSTM(nn.Module):
|
||||||
output,(h,c) = self.lstm(x)
|
output,(h,c) = self.lstm(x)
|
||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
|
|
||||||
return output
|
return output,(h,c)
|
||||||
|
|
||||||
class Demucs(nn.Module):
|
class Demucs(nn.Module):
|
||||||
|
|
||||||
|
|
@ -85,20 +85,23 @@ class Demucs(nn.Module):
|
||||||
self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
|
self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
|
||||||
|
|
||||||
def forward(self,mixed_signal):
|
def forward(self,mixed_signal):
|
||||||
|
|
||||||
|
if mixed_signal.dim() == 2:
|
||||||
|
mixed_signal = mixed_signal.unsqueeze(1)
|
||||||
|
|
||||||
length = mixed_signal.shape[-1]
|
length = mixed_signal.shape[-1]
|
||||||
x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
|
x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
|
||||||
if self.resample>1:
|
if self.resample>1:
|
||||||
x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate,
|
x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate,
|
||||||
target_sr=int(self.sampling_rate * self.resample))
|
target_sr=int(self.sampling_rate * self.resample))
|
||||||
print("resampled->",x.shape)
|
|
||||||
encoder_outputs = []
|
encoder_outputs = []
|
||||||
for encoder in self.encoder:
|
for encoder in self.encoder:
|
||||||
x = encoder(x)
|
x = encoder(x)
|
||||||
print(x.shape)
|
print(x.shape)
|
||||||
encoder_outputs.append(x)
|
encoder_outputs.append(x)
|
||||||
x = x.permute(0,2,1)
|
x = x.permute(0,2,1)
|
||||||
x = self.de_lstm(x)
|
x,_ = self.de_lstm(x)
|
||||||
|
|
||||||
x = x.permute(0,2,1)
|
x = x.permute(0,2,1)
|
||||||
for decoder in self.decoder:
|
for decoder in self.decoder:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue