266 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			266 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
| import logging
 | |
| import math
 | |
| from typing import List, Optional, Union
 | |
| 
 | |
| import torch.nn.functional as F
 | |
| from torch import nn
 | |
| 
 | |
| from enhancer.data.dataset import EnhancerDataset
 | |
| from enhancer.models.model import Model
 | |
| from enhancer.utils.io import Audio as audio
 | |
| from enhancer.utils.utils import merge_dict
 | |
| 
 | |
| 
 | |
class DemucsLSTM(nn.Module):
    """Recurrent block: an ``nn.LSTM`` followed by a linear projection.

    The projection maps the LSTM output width (``2 * hidden_size`` when
    ``bidirectional`` is true, ``hidden_size`` otherwise) back down to
    ``hidden_size``.

    parameters:
        input_size: int
            number of features per step fed to the LSTM
        hidden_size: int
            LSTM hidden width, also the width of the projected output
        num_layers: int
            number of stacked LSTM layers
        bidirectional: bool, defaults to True
            whether the LSTM runs in both directions
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool = True,
    ):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
        )
        # A bidirectional LSTM concatenates both directions on the last axis.
        lstm_width = hidden_size * (2 if bidirectional else 1)
        self.linear = nn.Linear(lstm_width, hidden_size)

    def forward(self, x):
        """Apply the LSTM, then project each step to ``hidden_size``.

        ``batch_first`` is left at the ``nn.LSTM`` default, so ``x`` follows
        the sequence-first layout ``(seq_len, batch, input_size)``.

        Returns the projected output and the final ``(h, c)`` state tuple.
        """
        lstm_out, final_state = self.lstm(x)
        return self.linear(lstm_out), final_state
 | |
| 
 | |
| 
 | |
class DemucsEncoder(nn.Module):
    """One encoder stage: strided convolution, ReLU, 1x1 convolution, gate.

    parameters:
        num_channels: int
            number of input channels
        hidden_size: int
            number of output channels of the stage
        kernel_size: int
            kernel size of the first (strided) convolution
        stride: int, defaults to 1
            stride of the first convolution
        glu: bool, defaults to False
            when true the 1x1 convolution emits ``2 * hidden_size`` channels
            and ``nn.GLU`` over the channel axis halves them again; otherwise
            it emits ``hidden_size`` channels followed by ``nn.ReLU``
    """

    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
    ):
        super().__init__()
        if glu:
            gate, expanded_channels = nn.GLU(1), 2 * hidden_size
        else:
            gate, expanded_channels = nn.ReLU(), hidden_size

        strided_conv = nn.Conv1d(num_channels, hidden_size, kernel_size, stride)
        pointwise_conv = nn.Conv1d(
            hidden_size, expanded_channels, kernel_size=1, stride=1
        )
        self.encoder = nn.Sequential(
            strided_conv,
            nn.ReLU(),
            pointwise_conv,
            gate,
        )

    def forward(self, waveform):
        """Return the encoded representation of ``waveform``."""
        encoded = self.encoder(waveform)
        return encoded
 | |
| 
 | |
| 
 | |
class DemucsDecoder(nn.Module):
    """One decoder stage: 1x1 convolution, gate, transposed convolution.

    parameters:
        num_channels: int
            number of output channels of the stage
        hidden_size: int
            number of input channels of the stage
        kernel_size: int
            kernel size of the transposed convolution
        stride: int, defaults to 1
            stride of the transposed convolution
        glu: bool, defaults to False
            when true the 1x1 convolution emits ``2 * hidden_size`` channels
            and ``nn.GLU`` over the channel axis halves them again; otherwise
            it emits ``hidden_size`` channels followed by ``nn.ReLU``
        layer: int, defaults to 0
            index of the stage; a trailing ``nn.ReLU`` is appended for every
            stage whose index is greater than 0
    """

    def __init__(
        self,
        num_channels: int,
        hidden_size: int,
        kernel_size: int,
        stride: int = 1,
        glu: bool = False,
        layer: int = 0,
    ):
        super().__init__()
        if glu:
            gate, expanded_channels = nn.GLU(1), 2 * hidden_size
        else:
            gate, expanded_channels = nn.ReLU(), hidden_size

        pointwise_conv = nn.Conv1d(
            hidden_size, expanded_channels, kernel_size=1, stride=1
        )
        upsampling_conv = nn.ConvTranspose1d(
            hidden_size, num_channels, kernel_size, stride
        )
        self.decoder = nn.Sequential(pointwise_conv, gate, upsampling_conv)

        has_output_relu = layer > 0
        if has_output_relu:
            # The child is registered under the name "4"; kept as-is so module
            # names stay the same as before.
            self.decoder.add_module("4", nn.ReLU())

    def forward(
        self,
        waveform,
    ):
        """Return the decoded (upsampled) version of ``waveform``."""
        return self.decoder(waveform)
 | |
| 
 | |
| 
 | |
class Demucs(Model):
    """
    Demucs model from https://arxiv.org/pdf/1911.13254.pdf

    parameters:
        encoder_decoder: dict, optional
            keyword arguments passed to the encoder/decoder blocks, combined
            with ``ED_DEFAULTS`` through ``merge_dict``
        lstm : dict, optional
            keyword arguments passed to the LSTM block, combined with
            ``LSTM_DEFAULTS`` through ``merge_dict``
        num_channels: int, defaults to 1
            number of channels in input audio
        resample: int, defaults to 4
            factor by which the input is upsampled before the encoder and
            downsampled again after the decoder; resampling is skipped
            unless the value is greater than 1
        sampling_rate: int, defaults to 16KHz
            sampling rate of input audio
        lr : float, defaults to 1e-3
            learning rate used for training
        dataset: EnhancerDataset, optional
            EnhancerDataset object containing train/validation data for
            training. Its ``duration`` is forwarded to the base class, and
            when its sampling rate differs from ``sampling_rate`` a warning
            is logged and the dataset's sampling rate is used instead.
        loss : string or List of strings
            loss function to be used, available ("mse","mae","SI-SDR")
        metric : string or List of strings
            metric function to be used, available ("mse","mae","SI-SDR")

    """

    ED_DEFAULTS = {
        "initial_output_channels": 48,
        "kernel_size": 8,
        "stride": 4,
        "depth": 5,
        "glu": True,
        "growth_factor": 2,
    }
    LSTM_DEFAULTS = {
        "bidirectional": True,
        "num_layers": 2,
    }

    def __init__(
        self,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        num_channels: int = 1,
        resample: int = 4,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        loss: Union[str, List] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                # The dataset decides the rate the model actually works at.
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        # NOTE(review): the merged dicts are deliberately re-bound to the
        # argument names before save_hyperparameters is called, presumably so
        # that the stored hyperparameters are the merged values (Lightning
        # style lookup by name) - confirm against the Model base class.
        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        in_channels = num_channels
        hidden = encoder_decoder["initial_output_channels"]
        for layer in range(encoder_decoder["depth"]):
            self.encoder.append(
                DemucsEncoder(
                    num_channels=in_channels,
                    hidden_size=hidden,
                    kernel_size=encoder_decoder["kernel_size"],
                    stride=encoder_decoder["stride"],
                    glu=encoder_decoder["glu"],
                )
            )
            # Decoders are prepended so that they run in the reverse order of
            # the encoders; the decoder built for layer 0 therefore runs last
            # and is the only one without a trailing ReLU.
            self.decoder.insert(
                0,
                DemucsDecoder(
                    num_channels=in_channels,
                    hidden_size=hidden,
                    kernel_size=encoder_decoder["kernel_size"],
                    stride=encoder_decoder["stride"],
                    glu=encoder_decoder["glu"],
                    layer=layer,
                ),
            )

            in_channels = hidden
            # Use the merged configuration so that a caller-supplied
            # "growth_factor" is honoured instead of always the class default.
            hidden = encoder_decoder["growth_factor"] * hidden

        # After the loop in_channels is the width of the deepest encoder.
        self.de_lstm = DemucsLSTM(
            input_size=in_channels,
            hidden_size=in_channels,
            num_layers=lstm["num_layers"],
            bidirectional=lstm["bidirectional"],
        )

    def forward(self, waveform):
        """Enhance ``waveform`` and return a tensor of the same length.

        parameters:
            waveform: tensor
                either 2-D, in which case a channel axis is inserted at
                position 1, or 3-D with exactly one channel on axis 1

        returns:
            tensor cropped along its last axis to the input length

        raises:
            TypeError: if axis 1 of the (possibly unsqueezed) input is not 1
        """
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        if waveform.size(1) != 1:
            raise TypeError(
                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
            )

        length = waveform.shape[-1]
        # Right-pad so that the encoder/decoder stack gives back at least
        # `length` samples.
        x = F.pad(waveform, (0, self.get_padding_length(length) - length))

        resample = self.hparams.resample
        sampling_rate = self.hparams.sampling_rate
        upsampled_rate = int(sampling_rate * resample)
        if resample > 1:
            x = audio.resample_audio(
                audio=x,
                sr=sampling_rate,
                target_sr=upsampled_rate,
            )

        encoder_outputs = []
        for encoder in self.encoder:
            x = encoder(x)
            encoder_outputs.append(x)

        # The nn.LSTM inside DemucsLSTM is built without batch_first, so it
        # expects (time, batch, channels). Feeding (batch, time, channels)
        # would make it recur over the batch axis instead of over time.
        # NOTE: this changes the outputs of checkpoints trained with the
        # previous axis order; parameter shapes are unaffected.
        x = x.permute(2, 0, 1)
        x, _ = self.de_lstm(x)
        x = x.permute(1, 2, 0)  # back to (batch, channels, time)

        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            # Crop the skip connection to the current length before adding.
            x = x + skip_connection[..., : x.shape[-1]]
            x = decoder(x)

        if resample > 1:
            x = audio.resample_audio(
                x,
                upsampled_rate,
                sampling_rate,
            )

        return x[..., :length]

    def get_padding_length(self, input_length):
        """Return the padded length to use for an input of ``input_length``.

        The length is scaled by the resample factor, pushed through the
        length arithmetic of every encoder stage (never dropping below 1)
        and every decoder stage, then scaled back and rounded up.
        """
        config = self.hparams.encoder_decoder
        kernel_size = config["kernel_size"]
        stride = config["stride"]
        depth = config["depth"]
        resample = self.hparams.resample

        padded = math.ceil(input_length * resample)

        for _ in range(depth):  # encoder operation
            padded = math.ceil((padded - kernel_size) / stride) + 1
            padded = max(1, padded)

        for _ in range(depth):  # decoder operation
            padded = (padded - 1) * stride + kernel_size

        padded = math.ceil(padded / resample)

        return int(padded)
 |