275 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			275 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Python
		
	
	
	
import logging
 | 
						|
import math
 | 
						|
from typing import List, Optional, Union
 | 
						|
 | 
						|
import torch.nn.functional as F
 | 
						|
from torch import nn
 | 
						|
 | 
						|
from enhancer.data.dataset import EnhancerDataset
 | 
						|
from enhancer.models.model import Model
 | 
						|
from enhancer.utils.io import Audio as audio
 | 
						|
from enhancer.utils.utils import merge_dict
 | 
						|
 | 
						|
 | 
						|
class DemucsLSTM(nn.Module):
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        input_size: int,
 | 
						|
        hidden_size: int,
 | 
						|
        num_layers: int,
 | 
						|
        bidirectional: bool = True,
 | 
						|
    ):
 | 
						|
        super().__init__()
 | 
						|
        self.lstm = nn.LSTM(
 | 
						|
            input_size, hidden_size, num_layers, bidirectional=bidirectional
 | 
						|
        )
 | 
						|
        dim = 2 if bidirectional else 1
 | 
						|
        self.linear = nn.Linear(dim * hidden_size, hidden_size)
 | 
						|
 | 
						|
    def forward(self, x):
 | 
						|
 | 
						|
        output, (h, c) = self.lstm(x)
 | 
						|
        output = self.linear(output)
 | 
						|
 | 
						|
        return output, (h, c)
 | 
						|
 | 
						|
 | 
						|
class DemucsEncoder(nn.Module):
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        num_channels: int,
 | 
						|
        hidden_size: int,
 | 
						|
        kernel_size: int,
 | 
						|
        stride: int = 1,
 | 
						|
        glu: bool = False,
 | 
						|
    ):
 | 
						|
        super().__init__()
 | 
						|
        activation = nn.GLU(1) if glu else nn.ReLU()
 | 
						|
        multi_factor = 2 if glu else 1
 | 
						|
        self.encoder = nn.Sequential(
 | 
						|
            nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
 | 
						|
            nn.ReLU(),
 | 
						|
            nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
 | 
						|
            activation,
 | 
						|
        )
 | 
						|
 | 
						|
    def forward(self, waveform):
 | 
						|
 | 
						|
        return self.encoder(waveform)
 | 
						|
 | 
						|
 | 
						|
class DemucsDecoder(nn.Module):
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        num_channels: int,
 | 
						|
        hidden_size: int,
 | 
						|
        kernel_size: int,
 | 
						|
        stride: int = 1,
 | 
						|
        glu: bool = False,
 | 
						|
        layer: int = 0,
 | 
						|
    ):
 | 
						|
        super().__init__()
 | 
						|
        activation = nn.GLU(1) if glu else nn.ReLU()
 | 
						|
        multi_factor = 2 if glu else 1
 | 
						|
        self.decoder = nn.Sequential(
 | 
						|
            nn.Conv1d(hidden_size, hidden_size * multi_factor, 1, 1),
 | 
						|
            activation,
 | 
						|
            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
 | 
						|
        )
 | 
						|
        if layer > 0:
 | 
						|
            self.decoder.add_module("4", nn.ReLU())
 | 
						|
 | 
						|
    def forward(
 | 
						|
        self,
 | 
						|
        waveform,
 | 
						|
    ):
 | 
						|
 | 
						|
        out = self.decoder(waveform)
 | 
						|
        return out
 | 
						|
 | 
						|
 | 
						|
class Demucs(Model):
 | 
						|
    """
 | 
						|
    Demucs model from https://arxiv.org/pdf/1911.13254.pdf
 | 
						|
    parameters:
 | 
						|
        encoder_decoder: dict, optional
 | 
						|
            keyword arguments passsed to encoder decoder block
 | 
						|
        lstm : dict, optional
 | 
						|
            keyword arguments passsed to LSTM block
 | 
						|
        num_channels: int, defaults to 1
 | 
						|
            number channels in input audio
 | 
						|
        sampling_rate: int, defaults to 16KHz
 | 
						|
            sampling rate of input audio
 | 
						|
        lr : float, defaults to 1e-3
 | 
						|
            learning rate used for training
 | 
						|
        dataset: EnhancerDataset, optional
 | 
						|
            EnhancerDataset object containing train/validation data for training
 | 
						|
        duration : float, optional
 | 
						|
            chunk duration in seconds
 | 
						|
        loss : string or List of strings
 | 
						|
            loss function to be used, available ("mse","mae","SI-SDR")
 | 
						|
        metric : string or List of strings
 | 
						|
            metric function to be used, available ("mse","mae","SI-SDR")
 | 
						|
 | 
						|
    """
 | 
						|
 | 
						|
    ED_DEFAULTS = {
 | 
						|
        "initial_output_channels": 48,
 | 
						|
        "kernel_size": 8,
 | 
						|
        "stride": 4,
 | 
						|
        "depth": 5,
 | 
						|
        "glu": True,
 | 
						|
        "growth_factor": 2,
 | 
						|
    }
 | 
						|
    LSTM_DEFAULTS = {
 | 
						|
        "bidirectional": True,
 | 
						|
        "num_layers": 2,
 | 
						|
    }
 | 
						|
 | 
						|
    def __init__(
 | 
						|
        self,
 | 
						|
        encoder_decoder: Optional[dict] = None,
 | 
						|
        lstm: Optional[dict] = None,
 | 
						|
        num_channels: int = 1,
 | 
						|
        resample: int = 4,
 | 
						|
        sampling_rate=16000,
 | 
						|
        normalize=True,
 | 
						|
        lr: float = 1e-3,
 | 
						|
        dataset: Optional[EnhancerDataset] = None,
 | 
						|
        loss: Union[str, List] = "mse",
 | 
						|
        metric: Union[str, List] = "mse",
 | 
						|
        floor=1e-3,
 | 
						|
    ):
 | 
						|
        duration = (
 | 
						|
            dataset.duration if isinstance(dataset, EnhancerDataset) else None
 | 
						|
        )
 | 
						|
        if dataset is not None:
 | 
						|
            if sampling_rate != dataset.sampling_rate:
 | 
						|
                logging.warning(
 | 
						|
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
 | 
						|
                )
 | 
						|
                sampling_rate = dataset.sampling_rate
 | 
						|
        super().__init__(
 | 
						|
            num_channels=num_channels,
 | 
						|
            sampling_rate=sampling_rate,
 | 
						|
            lr=lr,
 | 
						|
            dataset=dataset,
 | 
						|
            duration=duration,
 | 
						|
            loss=loss,
 | 
						|
            metric=metric,
 | 
						|
        )
 | 
						|
 | 
						|
        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
 | 
						|
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
 | 
						|
        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
 | 
						|
        hidden = encoder_decoder["initial_output_channels"]
 | 
						|
        self.normalize = normalize
 | 
						|
        self.floor = floor
 | 
						|
        self.encoder = nn.ModuleList()
 | 
						|
        self.decoder = nn.ModuleList()
 | 
						|
 | 
						|
        for layer in range(encoder_decoder["depth"]):
 | 
						|
 | 
						|
            encoder_layer = DemucsEncoder(
 | 
						|
                num_channels=num_channels,
 | 
						|
                hidden_size=hidden,
 | 
						|
                kernel_size=encoder_decoder["kernel_size"],
 | 
						|
                stride=encoder_decoder["stride"],
 | 
						|
                glu=encoder_decoder["glu"],
 | 
						|
            )
 | 
						|
            self.encoder.append(encoder_layer)
 | 
						|
 | 
						|
            decoder_layer = DemucsDecoder(
 | 
						|
                num_channels=num_channels,
 | 
						|
                hidden_size=hidden,
 | 
						|
                kernel_size=encoder_decoder["kernel_size"],
 | 
						|
                stride=encoder_decoder["stride"],
 | 
						|
                glu=encoder_decoder["glu"],
 | 
						|
                layer=layer,
 | 
						|
            )
 | 
						|
            self.decoder.insert(0, decoder_layer)
 | 
						|
 | 
						|
            num_channels = hidden
 | 
						|
            hidden = self.ED_DEFAULTS["growth_factor"] * hidden
 | 
						|
 | 
						|
        self.de_lstm = DemucsLSTM(
 | 
						|
            input_size=num_channels,
 | 
						|
            hidden_size=num_channels,
 | 
						|
            num_layers=lstm["num_layers"],
 | 
						|
            bidirectional=lstm["bidirectional"],
 | 
						|
        )
 | 
						|
 | 
						|
    def forward(self, waveform):
 | 
						|
 | 
						|
        if waveform.dim() == 2:
 | 
						|
            waveform = waveform.unsqueeze(1)
 | 
						|
 | 
						|
        if waveform.size(1) != 1:
 | 
						|
            raise TypeError(
 | 
						|
                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
 | 
						|
            )
 | 
						|
        if self.normalize:
 | 
						|
            waveform = waveform.mean(dim=1, keepdim=True)
 | 
						|
            std = waveform.std(dim=-1, keepdim=True)
 | 
						|
            waveform = waveform / (self.floor + std)
 | 
						|
        else:
 | 
						|
            std = 1
 | 
						|
        length = waveform.shape[-1]
 | 
						|
        x = F.pad(waveform, (0, self.get_padding_length(length) - length))
 | 
						|
        if self.hparams.resample > 1:
 | 
						|
            x = audio.resample_audio(
 | 
						|
                audio=x,
 | 
						|
                sr=self.hparams.sampling_rate,
 | 
						|
                target_sr=int(
 | 
						|
                    self.hparams.sampling_rate * self.hparams.resample
 | 
						|
                ),
 | 
						|
            )
 | 
						|
 | 
						|
        encoder_outputs = []
 | 
						|
        for encoder in self.encoder:
 | 
						|
            x = encoder(x)
 | 
						|
            encoder_outputs.append(x)
 | 
						|
        x = x.permute(0, 2, 1)
 | 
						|
        x, _ = self.de_lstm(x)
 | 
						|
 | 
						|
        x = x.permute(0, 2, 1)
 | 
						|
        for decoder in self.decoder:
 | 
						|
            skip_connection = encoder_outputs.pop(-1)
 | 
						|
            x = x + skip_connection[..., : x.shape[-1]]
 | 
						|
            x = decoder(x)
 | 
						|
 | 
						|
        if self.hparams.resample > 1:
 | 
						|
            x = audio.resample_audio(
 | 
						|
                x,
 | 
						|
                int(self.hparams.sampling_rate * self.hparams.resample),
 | 
						|
                self.hparams.sampling_rate,
 | 
						|
            )
 | 
						|
 | 
						|
        out = x[..., :length]
 | 
						|
        return std * out
 | 
						|
 | 
						|
    def get_padding_length(self, input_length):
 | 
						|
 | 
						|
        input_length = math.ceil(input_length * self.hparams.resample)
 | 
						|
 | 
						|
        for layer in range(
 | 
						|
            self.hparams.encoder_decoder["depth"]
 | 
						|
        ):  # encoder operation
 | 
						|
            input_length = (
 | 
						|
                math.ceil(
 | 
						|
                    (input_length - self.hparams.encoder_decoder["kernel_size"])
 | 
						|
                    / self.hparams.encoder_decoder["stride"]
 | 
						|
                )
 | 
						|
                + 1
 | 
						|
            )
 | 
						|
            input_length = max(1, input_length)
 | 
						|
        for layer in range(
 | 
						|
            self.hparams.encoder_decoder["depth"]
 | 
						|
        ):  # decoder operaration
 | 
						|
            input_length = (input_length - 1) * self.hparams.encoder_decoder[
 | 
						|
                "stride"
 | 
						|
            ] + self.hparams.encoder_decoder["kernel_size"]
 | 
						|
        input_length = math.ceil(input_length / self.hparams.resample)
 | 
						|
 | 
						|
        return int(input_length)
 |