156 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			156 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
| from typing import Optional, Union, List
 | |
| from torch import nn
 | |
| import torch.nn.functional as F
 | |
| import math 
 | |
| 
 | |
| from enhancer.models.model import Model
 | |
| from enhancer.data.dataset import EnhancerDataset
 | |
| from enhancer.utils.io import Audio as audio
 | |
| from enhancer.utils.utils import merge_dict
 | |
| 
 | |
class DeLSTM(nn.Module):
    """An ``nn.LSTM`` whose per-step output is projected back to ``hidden_size``.

    The wrapped LSTM is constructed with PyTorch's default
    ``batch_first=False``, so ``forward`` expects input laid out as
    (sequence, batch, input_size), per the ``nn.LSTM`` convention.

    Args:
        input_size: number of features per time step fed to the LSTM.
        hidden_size: LSTM hidden width, and the width of the projected output.
        num_layers: number of stacked LSTM layers.
        bidirectional: run the LSTM in both directions (default ``True``).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            bidirectional=bidirectional,
        )
        # A bidirectional LSTM concatenates both directions' outputs, so the
        # projection has to accept twice the hidden width.
        num_directions = 2 if bidirectional else 1
        self.linear = nn.Linear(num_directions * hidden_size, hidden_size)

    def forward(self, x):
        """Run the LSTM and project its output.

        Returns:
            ``(projected_output, (h, c))``, where ``(h, c)`` is the final
            state tuple returned by the underlying ``nn.LSTM``.
        """
        lstm_out, final_state = self.lstm(x)
        projected = self.linear(lstm_out)
        return projected, final_state
 | |
| 
 | |
class Demucs(Model):
    """Demucs-style waveform enhancement model.

    A stack of strided 1-D convolution encoder layers, a ``DeLSTM``
    bottleneck, and a mirrored stack of transposed-convolution decoder
    layers joined by skip connections. When ``resample > 1`` the (padded)
    input is resampled up by that factor before the network and back down
    afterwards.

    Args:
        encoder_decoder: overrides combined with ``ED_DEFAULTS`` via
            ``merge_dict`` (presumably user values take precedence — confirm).
            Keys: ``initial_output_channels``, ``kernel_size``, ``stride``,
            ``depth``, ``glu``, ``growth_factor``.
        lstm: overrides combined with ``LSTM_DEFAULTS`` via ``merge_dict``.
            Keys: ``bidirectional``, ``num_layers``.
        num_channels: input channels of the first encoder layer and output
            channels of the last decoder layer. ``forward`` only accepts
            mono input.
        resample: up-sampling factor applied before the encoder.
        sampling_rate: sampling rate of the input audio.
        lr: learning rate, forwarded to ``Model``.
        dataset: forwarded to ``Model``.
        loss: loss name(s), forwarded to ``Model``.
    """

    # Default encoder/decoder configuration.
    ED_DEFAULTS = {
        "initial_output_channels": 48,
        "kernel_size": 8,
        "stride": 1,
        "depth": 5,
        "glu": True,
        "growth_factor": 2,
    }
    # Default configuration for the DeLSTM bottleneck.
    LSTM_DEFAULTS = {
        "bidirectional": True,
        "num_layers": 2,
    }

    def __init__(
        self,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        num_channels: int = 1,
        resample: int = 4,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        loss: Union[str, List] = "mse",
    ):
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            loss=loss,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        # The merged configs are read back through self.hparams in
        # forward() and get_padding_length().
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")

        kernel_size = encoder_decoder["kernel_size"]
        stride = encoder_decoder["stride"]
        growth_factor = encoder_decoder["growth_factor"]
        hidden = encoder_decoder["initial_output_channels"]
        activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU()
        # GLU halves the channel dimension, so the conv feeding it must emit
        # twice as many channels.
        multi_factor = 2 if encoder_decoder["glu"] else 1

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for layer in range(encoder_decoder["depth"]):
            # Only the strided conv (and its transposed mirror) changes the
            # signal length, which is exactly what get_padding_length models.
            # The channel-mixing convs are therefore 1x1 (length-preserving).
            encoder_layer = nn.Sequential(
                nn.Conv1d(num_channels, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * multi_factor, 1),
                activation,
            )
            self.encoder.append(encoder_layer)

            decoder_layer = [
                nn.Conv1d(hidden, hidden * multi_factor, 1),
                activation,
                nn.ConvTranspose1d(hidden, num_channels, kernel_size, stride),
            ]
            # The outermost decoder layer (layer 0) produces the final
            # waveform, so it gets no trailing ReLU and can output negatives.
            if layer > 0:
                decoder_layer.append(nn.ReLU())
            # Prepend so that self.decoder runs deepest layer first.
            self.decoder.insert(0, nn.Sequential(*decoder_layer))

            num_channels = hidden
            hidden = growth_factor * hidden

        self.de_lstm = DeLSTM(
            input_size=num_channels,
            hidden_size=num_channels,
            num_layers=lstm["num_layers"],
            bidirectional=lstm["bidirectional"],
        )

    def forward(self, mixed_signal):
        """Enhance a batch of mono waveforms.

        Args:
            mixed_signal: tensor of shape (batch, time) or (batch, 1, time).

        Returns:
            Tensor of shape (batch, channels, time), trimmed back to the
            input's original number of samples.

        Raises:
            TypeError: if the channel dimension is not 1.
        """
        if mixed_signal.dim() == 2:
            mixed_signal = mixed_signal.unsqueeze(1)

        if mixed_signal.size(1) != 1:
            raise TypeError(f"Demucs can only process mono channel audio, input has {mixed_signal.size(1)} channels")

        length = mixed_signal.shape[-1]
        x = F.pad(mixed_signal, (0, self.get_padding_length(length) - length))
        if self.hparams.resample > 1:
            x = audio.pt_resample_audio(
                audio=x,
                sr=self.hparams.sampling_rate,
                target_sr=int(self.hparams.sampling_rate * self.hparams.resample),
            )

        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)

        # The DeLSTM's nn.LSTM is not batch_first: it expects
        # (time, batch, channels). Permuting to (batch, time, channels) would
        # make the recurrence run across the batch instead of over time.
        x = x.permute(2, 0, 1)
        x, _ = self.de_lstm(x)
        x = x.permute(1, 2, 0)

        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            # Out-of-place add: the previous decoder layer may end in a ReLU
            # whose output autograd saves for the backward pass, so an
            # in-place update would break gradient computation.
            x = x + skip_connection[..., : x.shape[-1]]
            x = decoder(x)

        if self.hparams.resample > 1:
            x = audio.pt_resample_audio(
                x,
                int(self.hparams.sampling_rate * self.hparams.resample),
                self.hparams.sampling_rate,
            )

        # Drop the samples introduced by the padding above.
        return x[..., :length]

    def get_padding_length(self, input_length):
        """Return the length to pad ``input_length`` samples up to.

        The result is the length (at the original sampling rate) after a
        round trip through the strided encoder convolutions (rounded up) and
        the transposed decoder convolutions, so it is >= ``input_length``.

        Args:
            input_length: number of samples in the un-padded input.

        Returns:
            The padded length as an ``int``.
        """
        config = self.hparams.encoder_decoder
        depth = config["depth"]
        kernel_size = config["kernel_size"]
        stride = config["stride"]

        length = math.ceil(input_length * self.hparams.resample)
        # Encoder: each strided conv maps L -> ceil((L - K) / S) + 1. Rounding
        # up (and never below 1) keeps trailing samples from being dropped.
        for _ in range(depth):
            length = max(1, math.ceil((length - kernel_size) / stride) + 1)
        # Decoder: each transposed conv maps L -> (L - 1) * S + K.
        for _ in range(depth):
            length = (length - 1) * stride + kernel_size
        return int(math.ceil(length / self.hparams.resample))
 | |
| 
 | |
|         
 | |
| 
 | |
|         
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
|          |