init dccrn
This commit is contained in:
		
							parent
							
								
									d3e052c5f3
								
							
						
					
					
						commit
						981763207a
					
				|  | @ -0,0 +1,240 @@ | |||
| import logging | ||||
| from typing import Any, List, Optional, Tuple, Union | ||||
| 
 | ||||
| from torch import nn | ||||
| 
 | ||||
| from enhancer.data import EnhancerDataset | ||||
| from enhancer.models import Model | ||||
| from enhancer.models.complexnn import ComplexConv2d, ComplexLSTM | ||||
| from enhancer.models.complexnn.conv import ComplexConvTranspose2d | ||||
| from enhancer.models.complexnn.utils import ComplexBatchNorm2D, ComplexRelu | ||||
| from enhancer.utils.transforms import ConviSTFT, ConvSTFT | ||||
| from enhancer.utils.utils import merge_dict | ||||
| 
 | ||||
| 
 | ||||
class DCCRN_ENCODER(nn.Module):
    """Single encoder stage: complex 2-D convolution, normalisation, activation.

    Args:
        in_channels: input channel count handed to ``ComplexConv2d``.
        out_channel: output channel count handed to ``ComplexConv2d`` and to
            the normalisation layer.
        kernel_size: convolution kernel size.
        complex_norm: pick ``ComplexBatchNorm2D`` when true, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: pick ``ComplexRelu`` when true, otherwise ``nn.PReLU``.
        stride: convolution stride.
        padding: convolution padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()

        # Choose the normalisation class first; it is instantiated only after
        # the convolution so the layers are created in pipeline order.
        if complex_norm:
            norm_cls = ComplexBatchNorm2D
        else:
            norm_cls = nn.BatchNorm2d

        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        conv = ComplexConv2d(
            in_channels,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        norm = norm_cls(out_channel)

        stages = [conv, norm, nonlinearity]
        self.encoder = nn.Sequential(*stages)

    def forward(self, waveform):
        """Run the input through convolution, normalisation and activation."""
        encoded = self.encoder(waveform)
        return encoded
| 
 | ||||
| 
 | ||||
class DCCRN_DECODER(nn.Module):
    """Single decoder stage: complex transposed 2-D convolution, normalisation,
    activation.

    Args:
        in_channels: input channel count handed to ``ComplexConvTranspose2d``.
        out_channels: output channel count handed to
            ``ComplexConvTranspose2d`` and to the normalisation layer.
        kernel_size: transposed-convolution kernel size.
        complex_norm: pick ``ComplexBatchNorm2D`` when true, otherwise
            ``nn.BatchNorm2d``.
        complex_relu: pick ``ComplexRelu`` when true, otherwise ``nn.PReLU``.
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
        output_padding: transposed-convolution output padding.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()

        # Choose the normalisation class first; it is instantiated only after
        # the transposed convolution so layers are created in pipeline order.
        if complex_norm:
            norm_cls = ComplexBatchNorm2D
        else:
            norm_cls = nn.BatchNorm2d

        if complex_relu:
            nonlinearity = ComplexRelu()
        else:
            nonlinearity = nn.PReLU()

        deconv = ComplexConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
        )
        norm = norm_cls(out_channels)

        stages = [deconv, norm, nonlinearity]
        self.decoder = nn.Sequential(*stages)

    def forward(self, waveform):
        """Run the input through transposed convolution, normalisation and
        activation."""
        decoded = self.decoder(waveform)
        return decoded
| 
 | ||||
| 
 | ||||
class DCCRN(Model):
    """DCCRN model: STFT front-end, complex encoder/decoder stacks and an LSTM.

    Only construction is implemented here; ``forward`` is still a placeholder
    that returns its input unchanged.

    Args:
        stft: overrides for ``STFT_DEFAULTS`` (``window_len``, ``hop_size``,
            ``nfft``, ``window``).
        encoder_decoder: overrides for ``ED_DEFAULTS``.
        lstm: overrides for ``LSTM_DEFAULTS``.
        complex_lstm: build a stack of ``ComplexLSTM`` layers when true,
            otherwise a single ``nn.LSTM``.
        complex_norm: forwarded to every encoder/decoder stage.
        complex_relu: forwarded to every encoder/decoder stage.
        masking_mode: stored on the instance; not used during construction.
        num_channels: number of input channels.
        sampling_rate: model sampling rate; replaced by the dataset's rate
            (with a warning) when the two differ.
        lr: learning rate forwarded to ``Model``.
        dataset: optional dataset forwarded to ``Model``.
        duration: duration forwarded to ``Model``; taken from ``dataset`` when
            an ``EnhancerDataset`` is supplied.
        loss: loss specification forwarded to ``Model``.
        metric: metric specification forwarded to ``Model``.
    """

    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate: int = 16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        # The dataset's duration wins when a dataset is given; otherwise keep
        # whatever the caller passed instead of silently discarding it.
        if isinstance(dataset, EnhancerDataset):
            duration = dataset.duration
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        # Fill in defaults for any option the caller left out.
        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        # complex_relu determines which activation modules are built, so it is
        # recorded together with the other architecture options.
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "complex_relu",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # NOTE(review): presumably doubled to hold real and imaginary parts
        # side by side in the channel axis - confirm against ComplexConv2d.
        num_channels *= 2
        hidden_size = encoder_decoder["initial_output_channels"]
        growth_factor = encoder_decoder["growth_factor"]
        depth = encoder_decoder["depth"]

        for layer in range(depth):
            encoder_ = DCCRN_ENCODER(
                num_channels,
                hidden_size,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 1),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.encoder.append(encoder_)

            # The decoder stage takes twice the encoder's output channels and
            # maps back to the encoder's input channels.
            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
                output_padding=(encoder_decoder["output_padding"], 0),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            # Prepend so the decoder list is in the reverse order of the
            # encoder list.
            self.decoder.insert(0, decoder_)

            # The next stage always consumes this stage's output channels; the
            # channel count only grows for the first ``depth - 3`` stages.
            num_channels = hidden_size
            if layer < depth - 3:
                hidden_size *= growth_factor

        # Feature size seen by the LSTM: half of the final channel count times
        # nfft / 2**depth, truncated to an int exactly once so both LSTM
        # variants receive an integer size.
        lstm_feature_size = int(
            (hidden_size / 2) * (stft["nfft"] / 2 ** depth)
        )

        if self.complex_lstm:
            num_lstm_layers = lstm["num_layers"]
            lstms = []
            for layer in range(num_lstm_layers):
                is_first = layer == 0
                is_last = layer == num_lstm_layers - 1
                lstms.append(
                    ComplexLSTM(
                        # Only the first layer reads the encoder features.
                        input_size=(
                            lstm_feature_size
                            if is_first
                            else lstm["hidden_size"]
                        ),
                        hidden_size=lstm["hidden_size"],
                        num_layers=1,
                        # Only the last layer projects back to the feature
                        # size.
                        projection_size=(
                            lstm_feature_size if is_last else None
                        ),
                    )
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            self.lstm = nn.LSTM(
                input_size=lstm_feature_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )

    def forward(self, waveform):
        """Placeholder forward pass: returns ``waveform`` unchanged."""
        return waveform
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786