demucs forward

shahules786 2022-09-05 17:12:03 +05:30
parent 9df1dafccf
commit 409afc31fc
1 changed file with 51 additions and 4 deletions

@@ -1,5 +1,9 @@
from typing import bool
from torch import nn
import torch.nn.functional as F
import math
from enhancer.utils.io import Audio as audio

class DeLSTM(nn.Module):
    def __init__(
@@ -35,8 +39,10 @@ class Demus(nn.Module):
        glu: bool = True,
        bidirectional: bool = True,
        resample: int = 2,
        sampling_rate=16000,
    ):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.hidden = hidden
@@ -46,10 +52,10 @@ class Demus(nn.Module):
        self.depth = depth
        self.bidirectional = bidirectional
        self.activation = nn.GLU(1) if glu else nn.ReLU()
        self.resample = resample
        self.sampling_rate = sampling_rate
        multi_factor = 2 if glu else 1
        ## do resampling
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
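
For context on the `multi_factor = 2 if glu else 1` bookkeeping above: `nn.GLU(1)` splits its input along the channel dimension into a value half and a gate half, so any layer feeding a GLU has to emit twice as many channels as one feeding a plain ReLU. A minimal standalone check (the tensor shape here is illustrative, not taken from the model):

import torch
from torch import nn

x = torch.randn(1, 96, 1000)  # (batch, channels, time)

glu = nn.GLU(dim=1)           # splits dim 1: first half = values, second half = gates
print(glu(x).shape)           # torch.Size([1, 48, 1000]) -- channels halved

relu = nn.ReLU()
print(relu(x).shape)          # torch.Size([1, 96, 1000]) -- unchanged, hence multi_factor = 1
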
@@ -78,7 +84,48 @@ class Demus(nn.Module):
        self.de_lstm = DeLSTM(input_size=c_in, hidden_size=c_in, num_layers=2, bidirectional=self.bidirectional)
    def forward(self, mixed_signal):
        length = mixed_signal.shape[-1]
        # zero-pad on the right so the signal survives the encoder/decoder strides intact
        x = F.pad(mixed_signal, (0, self.get_padding_length(length) - length))
        if self.resample > 1:
            # upsample before encoding, mirroring the downsample after decoding
            x = audio.resample_audio(
                x, self.sampling_rate, int(self.sampling_rate * self.resample)
            )

        # encoder pass, keeping activations for the skip connections
        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)

        x, _ = self.de_lstm(x)

        # decoder pass with U-Net style skip connections
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            x = x + skip_connection[..., : x.shape[-1]]
            x = decoder(x)

        if self.resample > 1:
            x = audio.resample_audio(
                x, int(self.sampling_rate * self.resample), self.sampling_rate
            )

        return x
    def get_padding_length(self, input_length):
        # walk the length through resampling, encoder strides and decoder strides
        # to find the nearest length the network can process without fractional frames
        input_length = math.ceil(input_length * self.resample)
        for _ in range(self.depth):  # encoder operation
            input_length = math.ceil((input_length - self.kernel_size) / self.stride) + 1
            input_length = max(1, input_length)
        for _ in range(self.depth):  # decoder operation
            input_length = (input_length - 1) * self.stride + self.kernel_size
        input_length = math.ceil(input_length / self.resample)
        return int(input_length)
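
As a sanity check on `get_padding_length`, the same arithmetic can be run standalone. A small sketch, assuming illustrative hyperparameters (depth=4, kernel_size=8, stride=4, resample=2; the model's actual defaults are not visible in this diff):

import math

def padded_length(length, depth=4, kernel_size=8, stride=4, resample=2):
    # mirror of Demus.get_padding_length: upsample, shrink through the encoder
    # strides, grow back through the decoder strides, then undo the resampling
    length = math.ceil(length * resample)
    for _ in range(depth):  # encoder convolutions
        length = max(1, math.ceil((length - kernel_size) / stride) + 1)
    for _ in range(depth):  # decoder transposed convolutions
        length = (length - 1) * stride + kernel_size
    return int(math.ceil(length / resample))

length = 16000
target = padded_length(length)
print(target, target - length)  # 16042 42 for these assumed values, so forward() right-pads by 42 samples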