fix shapes
parent 8a43354cb0
commit b42ca28851
@@ -1,6 +1,5 @@
-from typing import bool
 from torch import nn
-import torch.functional as F
+import torch.nn.functional as F
 import math
 
 from enhancer.utils.io import Audio as audio
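Note: `bool` is a builtin rather than a member of `typing`, so the removed import raises `ImportError` at module load, and `pad` lives in `torch.nn.functional`, not `torch.functional`. A quick check of the corrected import:

```python
import torch
import torch.nn.functional as F  # torch.functional has no pad()

# `from typing import bool` raises ImportError: bool is a builtin.
print(F.pad(torch.zeros(1, 4), (0, 2)).shape)  # torch.Size([1, 6])
```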
@@ -14,6 +13,7 @@ class DeLSTM(nn.Module):
         bidirectional:bool=True
     ):
+        super().__init__()
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
         dim = 2 if bidirectional else 1
         self.linear = nn.Linear(dim*hidden_size,hidden_size)
 
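Note: `nn.Module` refuses submodule assignment before its own constructor has run (PyTorch raises "cannot assign module before Module.__init__() call"), which is why `super().__init__()` has to precede `self.lstm = ...`. A minimal sketch of the wrapper; the sizes are illustrative, and `batch_first=True` is an assumption that matches the `(batch, time, features)` layout the new `permute` calls in `forward` produce (the diff's own `nn.LSTM` call leaves it at the default):

```python
import torch
from torch import nn

class DeLSTM(nn.Module):
    def __init__(self, input_size=768, hidden_size=768, num_layers=2,
                 bidirectional=True):
        super().__init__()  # must run before any submodule is assigned
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            bidirectional=bidirectional, batch_first=True)
        dim = 2 if bidirectional else 1
        self.linear = nn.Linear(dim * hidden_size, hidden_size)

    def forward(self, x):               # x: (batch, time, input_size)
        output, _ = self.lstm(x)        # (batch, time, dim * hidden_size)
        return self.linear(output)      # (batch, time, hidden_size)
```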
@@ -25,7 +25,7 @@ class DeLSTM(nn.Module):
 
         return output
 
-class Demus(nn.Module):
+class Demucs(nn.Module):
 
     def __init__(
         self,
@@ -35,10 +35,10 @@ class Demus(nn.Module):
         kernel_size:int=8,
         stride:int=4,
         growth_factor:int=2,
-        depth:int = 6,
+        depth:int = 5,
         glu:bool = True,
         bidirectional:bool=True,
-        resample:int=2,
+        resample:int=4,
         sampling_rate = 16000
 
     ):
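Note: `resample=4` mirrors the upsample-by-4 trick used in the original Demucs denoiser, where the waveform is upsampled before the encoder and downsampled back at the end. The diff only shows the call signature of `Audio.pt_resample_audio` (`audio=`, `sr=`, `target_sr=`); a plausible stand-in built on `torchaudio`, purely for illustration:

```python
import torch
import torchaudio

# Hypothetical stand-in for enhancer.utils.io.Audio.pt_resample_audio;
# the real helper's body is not shown in this diff.
def pt_resample_audio(audio: torch.Tensor, sr: int, target_sr: int) -> torch.Tensor:
    return torchaudio.functional.resample(audio, orig_freq=sr, new_freq=target_sr)

x = torch.randn(1, 1, 16000)              # 1 s of audio at 16 kHz
up = pt_resample_audio(x, 16000, 64000)   # resample=4 -> process at 64 kHz
print(up.shape)                           # torch.Size([1, 1, 64000])
```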
@@ -65,8 +65,8 @@ class Demus(nn.Module):
                 nn.ReLU(),
                 nn.Conv1d(hidden, hidden*multi_factor,kernel_size,1),
                 self.activation]
-            encoder_layer = nn.Sequential(encoder_layer)
-            self.encoder.append(*encoder_layer)
+            encoder_layer = nn.Sequential(*encoder_layer)
+            self.encoder.append(encoder_layer)
 
             decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,kernel_size,1),
                 self.activation,
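Note: the star was on the wrong call: `nn.Sequential` takes modules as separate positional arguments, so a Python list has to be unpacked, while `nn.ModuleList.append` takes exactly one module. A self-contained illustration (layer sizes are made up):

```python
import torch
from torch import nn

layers = [nn.Conv1d(1, 48, kernel_size=8, stride=4), nn.ReLU()]

block = nn.Sequential(*layers)   # unpack the list into positional args
encoder = nn.ModuleList()
encoder.append(block)            # append registers a single module

x = torch.randn(1, 1, 16000)     # (batch, channels, time)
print(encoder[0](x).shape)       # torch.Size([1, 48, 3999])
```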
@@ -87,25 +87,27 @@ class Demus(nn.Module):
     def forward(self,mixed_signal):
 
         length = mixed_signal.shape[-1]
-        x = F.pad((0,self.get_padding_length(length) - length))
+        x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
         if self.resample>1:
-            x = audio.resample_audio(audio=x,
-                    sampling_rate = int(self.sampling_rate * self.resample))
+            x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate,
+                    target_sr=int(self.sampling_rate * self.resample))
+        print("resampled->",x.shape)
         encoder_outputs = []
         for encoder in self.encoder:
             x = encoder(x)
+            print(x.shape)
             encoder_outputs.append(x)
 
-        x,_ = self.de_lstm(x)
+        x = x.permute(0,2,1)
+        x = self.de_lstm(x)
 
+        x = x.permute(0,2,1)
         for decoder in self.decoder:
             skip_connection = encoder_outputs.pop(-1)
             x += skip_connection[..., :x.shape[-1]]
             x = decoder(x)
 
         if self.resample > 1:
-            x = audio.resample_audio(x,int(self.sampling_rate * self.resample),
+            x = audio.pt_resample_audio(x,int(self.sampling_rate * self.resample),
                     self.sampling_rate)
 
         return x
 
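Note: two shape fixes drive this hunk. `F.pad` takes the tensor as its first argument, with `(left, right)` widths for the last dimension, and the conv encoder emits `(batch, channels, time)` while an LSTM run in batch-first layout consumes `(batch, time, features)`, hence the `permute(0,2,1)` on either side of the bottleneck. A sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 1, 16000)      # (batch, channels, time)
x = F.pad(x, (0, 384))            # right-pad the last (time) dimension
print(x.shape)                    # torch.Size([2, 1, 16384])

# Around the recurrent bottleneck: swap channel/time axes for the LSTM,
# then swap back for the Conv1d decoders.
lstm = torch.nn.LSTM(input_size=768, hidden_size=768, batch_first=True)
h = torch.randn(2, 768, 63)       # encoder output: (batch, channels, time)
h = h.permute(0, 2, 1)            # -> (batch, time, channels)
h, _ = lstm(h)
h = h.permute(0, 2, 1)            # -> (batch, channels, time)
print(h.shape)                    # torch.Size([2, 768, 63])
```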