init dccrn
shahules786

parent d3e052c5f3
commit 981763207a

@@ -0,0 +1,240 @@
import logging
from typing import Any, List, Optional, Tuple, Union

from torch import nn

from enhancer.data import EnhancerDataset
from enhancer.models import Model
from enhancer.models.complexnn import ComplexConv2d, ComplexLSTM
from enhancer.models.complexnn.conv import ComplexConvTranspose2d
from enhancer.models.complexnn.utils import ComplexBatchNorm2D, ComplexRelu
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict

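# One DCCRN encoder stage: a complex 2-D convolution followed by (complex)
# batch normalisation and a (complex) activation. With complex_norm /
# complex_relu set to False the stage falls back to nn.BatchNorm2d / nn.PReLU.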
class DCCRN_ENCODER(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channel: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 1),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        self.encoder = nn.Sequential(
            ComplexConv2d(
                in_channels,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            batchnorm(out_channel),
            activation,
        )

    def forward(self, waveform):
        return self.encoder(waveform)

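# One DCCRN decoder stage: the transposed-convolution counterpart of the
# encoder stage above, with output_padding so the upsampled feature map lines
# up with the corresponding encoder output.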
class DCCRN_DECODER(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
        padding: Tuple[int, int] = (2, 0),
        output_padding: Tuple[int, int] = (1, 0),
    ):
        super().__init__()
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        self.decoder = nn.Sequential(
            ComplexConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            batchnorm(out_channels),
            activation,
        )

    def forward(self, waveform):
        return self.decoder(waveform)

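# DCCRN model, in the spirit of "DCCRN: Deep Complex Convolution Recurrent
# Network for Phase-Aware Speech Enhancement" (Hu et al., Interspeech 2020):
# a convolutional STFT front-end, a stack of complex encoder stages, a
# (complex) LSTM bottleneck, a mirrored decoder stack and an inverse STFT.
# masking_mode (default "E") selects the masking strategy; only the module
# construction is done in this initial commit.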
class DCCRN(Model):

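    # Default configuration blocks. Each can be partially overridden by
    # passing a dict to the constructor (stft=..., encoder_decoder=...,
    # lstm=...); merge_dict combines the user dict with these defaults.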
    STFT_DEFAULTS = {
        "window_len": 400,
        "hop_size": 100,
        "nfft": 512,
        "window": "hamming",
    }

    ED_DEFAULTS = {
        "initial_output_channels": 32,
        "depth": 6,
        "kernel_size": 5,
        "growth_factor": 2,
        "stride": 2,
        "padding": 2,
        "output_padding": 1,
    }

    LSTM_DEFAULTS = {
        "num_layers": 2,
        "hidden_size": 256,
    }

    def __init__(
        self,
        stft: Optional[dict] = None,
        encoder_decoder: Optional[dict] = None,
        lstm: Optional[dict] = None,
        complex_lstm: bool = True,
        complex_norm: bool = True,
        complex_relu: bool = True,
        masking_mode: str = "E",
        num_channels: int = 1,
        sampling_rate=16000,
        lr: float = 1e-3,
        dataset: Optional[EnhancerDataset] = None,
        duration: Optional[float] = None,
        loss: Union[str, List, Any] = "mse",
        metric: Union[str, List] = "mse",
    ):
        duration = (
            dataset.duration
            if isinstance(dataset, EnhancerDataset)
            else duration
        )
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    f"model sampling rate {sampling_rate} does not match dataset sampling rate {dataset.sampling_rate}; using the dataset sampling rate"
                )
                sampling_rate = dataset.sampling_rate
        super().__init__(
            num_channels=num_channels,
            sampling_rate=sampling_rate,
            lr=lr,
            dataset=dataset,
            duration=duration,
            loss=loss,
            metric=metric,
        )

        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
        stft = merge_dict(self.STFT_DEFAULTS, stft)
        self.save_hyperparameters(
            "encoder_decoder",
            "lstm",
            "stft",
            "complex_lstm",
            "complex_norm",
            "masking_mode",
        )
        self.complex_lstm = complex_lstm
        self.complex_norm = complex_norm
        self.masking_mode = masking_mode

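        # STFT analysis and inverse-STFT synthesis front- and back-ends,
        # implemented as convolution modules (ConvSTFT / ConviSTFT) so the
        # transforms live inside the model rather than in the data pipeline.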
        self.stft = ConvSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )
        self.istft = ConviSTFT(
            stft["window_len"], stft["hop_size"], stft["nfft"], stft["window"]
        )

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        num_channels *= 2
        hidden_size = encoder_decoder["initial_output_channels"]
        growth_factor = encoder_decoder["growth_factor"]

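        # Build the encoder and decoder stacks. num_channels was doubled above
        # since real and imaginary parts are stacked along the channel axis.
        # Each decoder stage takes 2 * hidden_size input channels, presumably
        # for a concatenated skip connection from the matching encoder stage
        # (the forward pass is not implemented yet), and decoders are inserted
        # at index 0 so the decoder stack mirrors the encoder order.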
        for layer in range(encoder_decoder["depth"]):
            encoder_ = DCCRN_ENCODER(
                num_channels,
                hidden_size,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 1),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.encoder.append(encoder_)

            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
                output_padding=(encoder_decoder["output_padding"], 0),
                complex_norm=complex_norm,
                complex_relu=complex_relu,
            )
            self.decoder.insert(0, decoder_)

            if layer < encoder_decoder["depth"] - 3:
                num_channels = hidden_size
                hidden_size *= growth_factor
            else:
                num_channels = hidden_size

        kernel_size = hidden_size // 2
        hidden_size = stft["nfft"] // 2 ** (encoder_decoder["depth"])

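        # Bottleneck sizing: after `depth` stride-2 encoder stages the
        # frequency axis holds nfft / 2**depth bins. Note the names are reused
        # here: `kernel_size` now holds half the final channel count and
        # `hidden_size` the remaining frequency bins; their product is the
        # flattened feature size fed to the LSTM below.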
        if self.complex_lstm:
            lstms = []
            for layer in range(lstm["num_layers"]):
                if layer == 0:
                    input_size = int(hidden_size * kernel_size)
                else:
                    input_size = lstm["hidden_size"]

                if layer == lstm["num_layers"] - 1:
                    projection_size = int(hidden_size * kernel_size)
                else:
                    projection_size = None

                kwargs = {
                    "input_size": input_size,
                    "hidden_size": lstm["hidden_size"],
                    "num_layers": 1,
                }

                lstms.append(
                    ComplexLSTM(projection_size=projection_size, **kwargs)
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            self.lstm = nn.LSTM(
                input_size=hidden_size * kernel_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )

    def forward(self, waveform):
        # Placeholder from this initial commit ("init dccrn"): the full DCCRN
        # forward pass (STFT analysis, encoder, LSTM, decoder, masking, iSTFT
        # synthesis) is not wired up yet, so the input is returned unchanged.
        return waveform
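# Minimal usage sketch (hypothetical values; forward() above is still an
# identity placeholder, so the output currently equals the input):
#
#     model = DCCRN(lstm={"hidden_size": 128})
#     noisy = torch.randn(1, 1, 16000)  # (batch, channels, samples)
#     enhanced = model(noisy)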