DCCRN implementation

This commit is contained in:
parent fc33bd83b6
commit 511d2141d4
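The new forward pass consumes the ConvSTFT output as stacked real and imaginary frequency bins and converts them to magnitude and phase before encoding. A minimal standalone sketch of that split, with nfft and tensor shapes assumed purely for illustration:

import torch

nfft = 512                            # assumed FFT size
bins = nfft // 2 + 1
stft = torch.randn(1, 2 * bins, 100)  # (batch, real+imag bins, frames), dummy input
real, imag = stft[:, :bins], stft[:, bins:]
mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
phase_spec = torch.atan2(imag, real)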
					
@@ -1,6 +1,8 @@
import logging
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from enhancer.data import EnhancerDataset
@@ -12,6 +14,7 @@ from enhancer.models.complexnn import (
    ComplexLSTM,
    ComplexRelu,
)
from enhancer.models.complexnn.utils import complex_cat
from enhancer.utils.transforms import ConviSTFT, ConvSTFT
from enhancer.utils.utils import merge_dict

@@ -54,6 +57,7 @@ class DCCRN_DECODER(nn.Module):
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, int],
        layer: int = 0,
        complex_norm: bool = True,
        complex_relu: bool = True,
        stride: Tuple[int, int] = (2, 1),
@@ -64,18 +68,30 @@
        batchnorm = ComplexBatchNorm2D if complex_norm else nn.BatchNorm2d
        activation = ComplexRelu() if complex_relu else nn.PReLU()

        self.decoder = nn.Sequential(
            ComplexConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
            ),
            batchnorm(out_channels),
            activation,
        )
        if layer != 0:
            self.decoder = nn.Sequential(
                ComplexConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                ),
                batchnorm(out_channels),
                activation,
            )
        else:
            self.decoder = nn.Sequential(
                ComplexConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    output_padding=output_padding,
                )
            )

    def forward(self, waveform):

@@ -187,6 +203,7 @@ class DCCRN(Model):
            decoder_ = DCCRN_DECODER(
                hidden_size + hidden_size,
                num_channels,
                layer=layer,
                kernel_size=(encoder_decoder["kernel_size"], 2),
                stride=(encoder_decoder["stride"], 1),
                padding=(encoder_decoder["padding"], 0),
@@ -231,14 +248,83 @@
                )
            self.lstm = nn.Sequential(*lstms)
        else:
            self.lstm = nn.LSTM(
                input_size=hidden_size * kernel_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )
            # nn.LSTM returns (output, states), so it cannot be chained with
            # nn.Linear inside nn.Sequential; keep the output projection separate.
            self.lstm = nn.LSTM(
                input_size=hidden_size * kernel_size,
                hidden_size=lstm["hidden_size"],
                num_layers=lstm["num_layers"],
                dropout=0.0,
                batch_first=False,
            )
            self.transform = nn.Linear(
                lstm["hidden_size"], hidden_size * kernel_size
            )

    def forward(self, waveform):

        return waveform
        waveform_stft = self.stft(waveform)
        real = waveform_stft[:, : self.stft.nfft // 2 + 1]
        imag = waveform_stft[:, self.stft.nfft // 2 + 1 :]

        mag_spec = torch.sqrt(real**2 + imag**2 + 1e-9)
        phase_spec = torch.atan2(imag, real)
        complex_spec = torch.stack([mag_spec, phase_spec], 1)[:, :, 1:]

        encoder_outputs = []
        out = complex_spec
        for encoder in self.encoder:
            out = encoder(out)
            encoder_outputs.append(out)

        B, C, D, T = out.size()
        out = out.permute(3, 0, 1, 2)
        if self.complex_lstm:

            lstm_real = out[:, :, : C // 2]
            lstm_imag = out[:, :, C // 2 :]
            lstm_real = lstm_real.reshape(T, B, C // 2 * D)
            lstm_imag = lstm_imag.reshape(T, B, C // 2 * D)
            lstm_real, lstm_imag = self.lstm([lstm_real, lstm_imag])
            lstm_real = lstm_real.reshape(T, B, C // 2, D)
            lstm_imag = lstm_imag.reshape(T, B, C // 2, D)
            out = torch.cat([lstm_real, lstm_imag], 2)
        else:
            out = out.reshape(T, B, C * D)
            # nn.LSTM returns (output, states); project back to C * D features
            out, _ = self.lstm(out)
            out = self.transform(out)
            out = out.reshape(T, B, C, D)

        out = out.permute(1, 2, 3, 0)
        for decoder in self.decoder:
            skip_connection = encoder_outputs.pop(-1)
            out = complex_cat([skip_connection, out])
            out = decoder(out)
            out = out[..., 1:]
        mask_real, mask_imag = out[:, 0], out[:, 1]
        mask_real = F.pad(mask_real, [0, 0, 1, 0])
        mask_imag = F.pad(mask_imag, [0, 0, 1, 0])
        if self.masking_mode == "E":

            mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
            real_phase = mask_real / (mask_mag + 1e-8)
            imag_phase = mask_imag / (mask_mag + 1e-8)
            mask_phase = torch.atan2(imag_phase, real_phase)
            mask_mag = torch.tanh(mask_mag)
            est_mag = mask_mag * mag_spec
            est_phase = phase_spec + mask_phase
            # est = est_mag * (cos(est_phase) + i sin(est_phase))
            real = est_mag * torch.cos(est_phase)
            imag = est_mag * torch.sin(est_phase)

        elif self.masking_mode == "C":

            # complex multiplication of the spectrum by the mask
            real, imag = (
                real * mask_real - imag * mask_imag,
                real * mask_imag + imag * mask_real,
            )

        else:

            real = real * mask_real
            imag = imag * mask_imag

        spec = torch.cat([real, imag], 1)
        wav = self.istft(spec).squeeze(1)
        wav = wav.clamp_(-1, 1)
        return wav
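For reference, the polar ("E") masking branch above can be exercised on its own; a minimal sketch on dummy tensors, with shapes chosen only for illustration:

import torch

mag_spec = torch.rand(1, 257, 100)    # mixture magnitude (batch, bins, frames), dummy values
phase_spec = torch.rand(1, 257, 100)  # mixture phase
mask_real = torch.randn(1, 257, 100)  # decoder output, real part
mask_imag = torch.randn(1, 257, 100)  # decoder output, imaginary part

mask_mag = torch.sqrt(mask_real**2 + mask_imag**2)
mask_phase = torch.atan2(mask_imag / (mask_mag + 1e-8), mask_real / (mask_mag + 1e-8))
est_mag = torch.tanh(mask_mag) * mag_spec  # bounded magnitude mask
est_phase = phase_spec + mask_phase        # rotate the mixture phase
real = est_mag * torch.cos(est_phase)      # back to rectangular form
imag = est_mag * torch.sin(est_phase)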