demucs forward
This commit is contained in:
parent
9df1dafccf
commit
409afc31fc
|
|
@ -1,5 +1,9 @@
|
||||||
from typing import bool
|
from typing import bool
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import torch.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
from enhancer.utils.io import Audio as audio
|
||||||
|
|
||||||
class DeLSTM(nn.Module):
|
class DeLSTM(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -35,8 +39,10 @@ class Demus(nn.Module):
|
||||||
glu:bool = True,
|
glu:bool = True,
|
||||||
bidirectional:bool=True,
|
bidirectional:bool=True,
|
||||||
resample:int=2,
|
resample:int=2,
|
||||||
|
sampling_rate = 16000
|
||||||
|
|
||||||
):
|
):
|
||||||
|
super().__init__()
|
||||||
self.c_in = c_in
|
self.c_in = c_in
|
||||||
self.c_out = c_out
|
self.c_out = c_out
|
||||||
self.hidden = hidden
|
self.hidden = hidden
|
||||||
|
|
@ -46,10 +52,10 @@ class Demus(nn.Module):
|
||||||
self.depth = depth
|
self.depth = depth
|
||||||
self.bidirectional = bidirectional
|
self.bidirectional = bidirectional
|
||||||
self.activation = nn.GLU(1) if glu else nn.ReLU()
|
self.activation = nn.GLU(1) if glu else nn.ReLU()
|
||||||
|
self.resample = resample
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
multi_factor = 2 if glu else 1
|
multi_factor = 2 if glu else 1
|
||||||
|
|
||||||
## do resampling
|
|
||||||
|
|
||||||
self.encoder = nn.ModuleList()
|
self.encoder = nn.ModuleList()
|
||||||
self.decoder = nn.ModuleList()
|
self.decoder = nn.ModuleList()
|
||||||
|
|
||||||
|
|
@ -78,7 +84,48 @@ class Demus(nn.Module):
|
||||||
|
|
||||||
self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
|
self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
|
||||||
|
|
||||||
def forward(self, mixed_signal):
    """Run the encoder / LSTM / decoder stack on a noisy waveform.

    The last axis of ``mixed_signal`` is treated as the sample axis: the
    signal is zero-padded on the right up to ``get_padding_length``,
    optionally resampled by ``self.resample``, passed through every
    encoder, the LSTM bottleneck and every decoder (with skip
    connections), resampled back and finally trimmed to the original
    number of samples.

    Args:
        mixed_signal: input tensor; only its last axis is inspected here.
            # assumes (batch, channels, samples) -- TODO confirm with callers

    Returns:
        Tensor with the same last-axis length as ``mixed_signal``.
    """
    length = mixed_signal.shape[-1]

    # Pad only the right side of the last axis. ``nn.functional`` is used
    # because ``pad`` lives in torch.nn.functional, not torch.functional.
    x = nn.functional.pad(
        mixed_signal, (0, self.get_padding_length(length) - length)
    )

    if self.resample > 1:
        # NOTE(review): this call uses keywords (audio=, sampling_rate=)
        # while the call after the decoder passes three positionals --
        # confirm both match Audio.resample_audio's signature.
        x = audio.resample_audio(
            audio=x,
            sampling_rate=int(self.sampling_rate * self.resample),
        )

    encoder_outputs = []
    for encoder in self.encoder:
        x = encoder(x)
        encoder_outputs.append(x)

    x, _ = self.de_lstm(x)

    for decoder in self.decoder:
        skip_connection = encoder_outputs.pop(-1)
        # Out-of-place add: an in-place ``+=`` would overwrite a tensor
        # that autograd may still need for the backward pass.
        x = x + skip_connection[..., : x.shape[-1]]
        x = decoder(x)

    if self.resample > 1:
        x = audio.resample_audio(
            x,
            int(self.sampling_rate * self.resample),
            self.sampling_rate,
        )

    # Drop the samples that exist only because of the padding added above.
    return x[..., :length]
|
||||||
|
|
||||||
|
def get_padding_length(self, input_length):
    """Compute the length the input is padded to before the network runs.

    The length is traced through the model: scaled up by
    ``self.resample``, shrunk by ``self.depth`` strided encoder steps
    (never below 1), grown again by ``self.depth`` decoder steps, and
    finally scaled back down by ``self.resample``.

    Args:
        input_length: number of samples on the last axis of the input.

    Returns:
        The resulting length as an ``int``.
    """
    kernel, stride = self.kernel_size, self.stride
    samples = math.ceil(input_length * self.resample)

    # Encoder pass: each step shortens the signal, clamped to one frame.
    for _ in range(self.depth):
        frames = math.ceil((samples - kernel) / stride) + 1
        samples = frames if frames > 1 else 1

    # Decoder pass: each step undoes one encoder step.
    for _ in range(self.depth):
        samples = (samples - 1) * stride + kernel

    return int(math.ceil(samples / self.resample))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue