245 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			245 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
	
| import logging
 | |
| from typing import Any, List, Optional, Tuple, Union
 | |
| 
 | |
| from torch import nn
 | |
| 
 | |
| from enhancer.data import EnhancerDataset
 | |
| from enhancer.models import Model
 | |
| from enhancer.models.complexnn import (
 | |
|     ComplexBatchNorm2D,
 | |
|     ComplexConv2d,
 | |
|     ComplexConvTranspose2d,
 | |
|     ComplexLSTM,
 | |
|     ComplexRelu,
 | |
| )
 | |
| from enhancer.utils.transforms import ConviSTFT, ConvSTFT
 | |
| from enhancer.utils.utils import merge_dict
 | |
| 
 | |
| 
 | |
class DCCRN_ENCODER(nn.Module):
    """One encoder stage: ``ComplexConv2d`` -> batch norm -> activation.

    Args:
        in_channels: channel count handed to ``ComplexConv2d`` as its input size.
        out_channel: channel count produced by the convolution and consumed
            by the normalisation layer.
        kernel_size: 2-tuple kernel size forwarded to ``ComplexConv2d``.
        complex_norm: use ``ComplexBatchNorm2D`` when true, ``nn.BatchNorm2d``
            otherwise.
        complex_relu: use ``ComplexRelu`` when true, ``nn.PReLU`` otherwise.
        stride: 2-tuple stride forwarded to ``ComplexConv2d``.
        padding: 2-tuple padding forwarded to ``ComplexConv2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()

        # Pick the normalisation class and build the non-linearity up front.
        norm_cls = nn.BatchNorm2d
        if complex_norm:
            norm_cls = ComplexBatchNorm2D

        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        conv = ComplexConv2d(
            in_channels,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        # Order matters: the Sequential indices (0, 1, 2) name the parameters
        # in the state dict.
        stages = [conv, norm_cls(out_channel), nonlinearity]
        self.encoder = nn.Sequential(*stages)

    def forward(self, waveform):
        """Run the input through the conv/norm/activation stack and return it."""
        encoded = self.encoder(waveform)
        return encoded
 | |
| 
 | |
| 
 | |
class DCCRN_DECODER(nn.Module):
    """One decoder stage: ``ComplexConvTranspose2d`` -> batch norm -> activation.

    Args:
        in_channels: channel count handed to ``ComplexConvTranspose2d`` as its
            input size.
        out_channels: channel count produced by the transposed convolution and
            consumed by the normalisation layer.
        kernel_size: 2-tuple kernel size forwarded to the transposed convolution.
        complex_norm: use ``ComplexBatchNorm2D`` when true, ``nn.BatchNorm2d``
            otherwise.
        complex_relu: use ``ComplexRelu`` when true, ``nn.PReLU`` otherwise.
        stride: 2-tuple stride forwarded to the transposed convolution.
        padding: 2-tuple padding forwarded to the transposed convolution.
        output_padding: 2-tuple output padding forwarded to the transposed
            convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()

        # Pick the normalisation class and build the non-linearity up front.
        norm_cls = nn.BatchNorm2d
        if complex_norm:
            norm_cls = ComplexBatchNorm2D

        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        deconv = ComplexConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )

        # Order matters: the Sequential indices (0, 1, 2) name the parameters
        # in the state dict.
        stages = [deconv, norm_cls(out_channels), nonlinearity]
        self.decoder = nn.Sequential(*stages)

    def forward(self, waveform):
        """Run the input through the deconv/norm/activation stack and return it."""
        decoded = self.decoder(waveform)
        return decoded
 | |
| 
 | |
| 
 | |
class DCCRN(Model):
    """Container for the building blocks of a DCCRN-style network.

    The constructor builds an STFT / inverse-STFT pair, a stack of
    ``DCCRN_ENCODER`` stages, a mirrored stack of ``DCCRN_DECODER`` stages and
    a recurrent block (a chain of ``ComplexLSTM`` layers or a plain
    ``nn.LSTM``).

    NOTE(review): ``forward`` currently returns its input unchanged; none of
    the modules built here are applied yet.

    Args:
        stft: overrides for ``STFT_DEFAULTS`` (``window_len``, ``hop_size``,
            ``nfft``, ``window``).
        encoder_decoder: overrides for ``ED_DEFAULTS``.
        lstm: overrides for ``LSTM_DEFAULTS``.
        complex_lstm: build ``ComplexLSTM`` layers when true, a single
            ``nn.LSTM`` otherwise.
        complex_norm: forwarded to every encoder/decoder stage.
        complex_relu: forwarded to every encoder/decoder stage.
        masking_mode: stored on the instance; not otherwise used in this class.
        num_channels: forwarded to ``Model``; doubled to obtain the channel
            count of the first encoder stage.
        sampling_rate: forwarded to ``Model``; replaced by the dataset's
            sampling rate (with a warning) when the two differ.
        lr: forwarded to ``Model``.
        dataset: optional dataset, forwarded to ``Model``.
        duration: forwarded to ``Model``; when ``dataset`` is an
            ``EnhancerDataset`` its ``duration`` takes precedence.
        loss: forwarded to ``Model``.
        metric: forwarded to ``Model``.
    """

    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        # The dataset's duration wins when available; otherwise keep whatever
        # the caller passed instead of silently discarding it.
        if isinstance(dataset, EnhancerDataset):
            duration = dataset.duration

        if dataset is not None and sampling_rate != dataset.sampling_rate:
            logging.warning(
                f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
            )
            sampling_rate = dataset.sampling_rate

        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)

        # complex_relu changes which activation modules exist, so it has to be
        # recorded alongside the other structural options.
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "complex_relu",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # First encoder stage receives twice the configured channel count
        # (presumably real + imaginary parts -- TODO confirm).
        num_channels *= 2
        hidden_size = encoder_decoder["initial_output_channels"]
        growth_factor = encoder_decoder["growth_factor"]
        depth = encoder_decoder["depth"]

        for layer in range(depth):
            self.encoder.append(
                DCCRN_ENCODER(
                    num_channels,
                    hidden_size,
                    kernel_size=(encoder_decoder["kernel_size"], 2),
                    stride=(encoder_decoder["stride"], 1),
                    padding=(encoder_decoder["padding"], 1),
                    complex_norm=complex_norm,
                    complex_relu=complex_relu,
                )
            )

            # The decoder input is twice the encoder output width
            # (presumably to take a concatenated skip connection -- TODO
            # confirm once forward is implemented). Decoders are inserted at
            # the front so the list ends up in reverse order of the encoders.
            self.decoder.insert(
                0,
                DCCRN_DECODER(
                    hidden_size + hidden_size,
                    num_channels,
                    kernel_size=(encoder_decoder["kernel_size"], 2),
                    stride=(encoder_decoder["stride"], 1),
                    padding=(encoder_decoder["padding"], 0),
                    output_padding=(encoder_decoder["output_padding"], 0),
                    complex_norm=complex_norm,
                    complex_relu=complex_relu,
                )
            )

            # The next stage always consumes this stage's output; the width
            # only keeps growing until the last three stages.
            num_channels = hidden_size
            if layer < depth - 3:
                hidden_size *= growth_factor

        # Size of the recurrent block's input/output features.
        # NOTE(review): the nfft / 2**depth term assumes each stage halves the
        # first spatial axis, i.e. encoder_decoder["stride"] == 2 -- confirm.
        channel_factor = hidden_size / 2
        reduced_bins = stft["nfft"] / 2 ** depth
        lstm_io_size = int(reduced_bins * channel_factor)

        if self.complex_lstm:
            last_layer = lstm["num_layers"] - 1
            lstms = []
            for layer in range(lstm["num_layers"]):
                # Only the first layer sees the encoder features, and only the
                # last layer projects back to that size.
                lstms.append(
                    ComplexLSTM(
                        projection_size=(
                            lstm_io_size if layer == last_layer else None
                        ),
                        input_size=(
                            lstm_io_size if layer == 0 else lstm["hidden_size"]
                        ),
                        hidden_size=lstm["hidden_size"],
                        num_layers=1,
                    )
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            # nn.LSTM requires an integer input_size and the keyword is
            # spelled hidden_size.
            self.lstm = nn.LSTM(
                input_size=lstm_io_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )

    def forward(self, waveform):
        """Return ``waveform`` unchanged (placeholder; no module is applied)."""
        return waveform
 |