refactor arguments
Commit 79311444d8 (parent 4edc90deb0)
```diff
@@ -1,8 +1,12 @@
+from typing import Optional
 from torch import nn
 import torch.nn.functional as F
 import math
 
+from enhancer.models.model import Model
+from enhancer.data.dataset import EnhancerDataset
 from enhancer.utils.io import Audio as audio
+from enhancer.utils.utils import merge_dict
 
 class DeLSTM(nn.Module):
     def __init__(
```
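The new `merge_dict` import carries most of the refactor: per-group argument dicts are overlaid on class-level defaults below. Its implementation is not part of this diff; a minimal sketch of the assumed semantics:

```python
from typing import Optional

# Not part of this diff: a minimal sketch of what merge_dict is assumed to do,
# i.e. overlay a (possibly None) override dict on a defaults dict.
def merge_dict(defaults: dict, overrides: Optional[dict]) -> dict:
    merged = dict(defaults)       # copy, so the class-level defaults stay untouched
    if overrides is not None:
        merged.update(overrides)  # user-supplied keys win
    return merged
```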
```diff
@@ -25,64 +29,68 @@ class DeLSTM(nn.Module):
 
         return output,(h,c)
 
-class Demucs(nn.Module):
+class Demucs(Model):
+
+    ED_DEFAULTS = {
+        "initial_output_channels": 48,
+        "kernel_size": 8,
+        "stride": 1,
+        "depth": 5,
+        "glu": True,
+        "growth_factor": 2,
+    }
+    LSTM_DEFAULTS = {
+        "bidirectional": True,
+        "num_layers": 2,
+    }
 
     def __init__(
         self,
-        c_in:int=1,
-        c_out:int=1,
-        hidden:int=48,
-        kernel_size:int=8,
-        stride:int=4,
-        growth_factor:int=2,
-        depth:int = 5,
-        glu:bool = True,
-        bidirectional:bool=True,
+        encoder_decoder:Optional[dict]=None,
+        lstm:Optional[dict]=None,
+        num_channels:int=1,
         resample:int=4,
-        sampling_rate = 16000
+        sampling_rate = 16000,
+        dataset:Optional[EnhancerDataset]=None,
 
     ):
-        super().__init__()
-        self.c_in = c_in
-        self.c_out = c_out
-        self.hidden = hidden
-        self.growth_factor = growth_factor
-        self.stride = stride
-        self.kernel_size = kernel_size
-        self.depth = depth
-        self.bidirectional = bidirectional
-        self.activation = nn.GLU(1) if glu else nn.ReLU()
-        self.resample = resample
-        self.sampling_rate = sampling_rate
-        multi_factor = 2 if glu else 1
+        super().__init__(num_channels=num_channels,
+                         sampling_rate=sampling_rate, dataset=dataset)
+
+        encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
+        lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
+        self.save_hyperparameters("encoder_decoder", "lstm", "resample")
+
+        hidden = encoder_decoder["initial_output_channels"]
+        activation = nn.GLU(1) if encoder_decoder["glu"] else nn.ReLU()
+        multi_factor = 2 if encoder_decoder["glu"] else 1
 
         self.encoder = nn.ModuleList()
         self.decoder = nn.ModuleList()
 
-        for layer in range(self.depth):
+        for layer in range(encoder_decoder["depth"]):
 
-            encoder_layer = [nn.Conv1d(c_in,hidden,kernel_size,stride),
+            encoder_layer = [nn.Conv1d(num_channels,hidden,encoder_decoder["kernel_size"],encoder_decoder["stride"]),
                             nn.ReLU(),
-                            nn.Conv1d(hidden, hidden*multi_factor,kernel_size,1),
-                            self.activation]
+                            nn.Conv1d(hidden, hidden*multi_factor,encoder_decoder["kernel_size"],1),
+                            activation]
             encoder_layer = nn.Sequential(*encoder_layer)
             self.encoder.append(encoder_layer)
 
-            decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,kernel_size,1),
-                            self.activation,
-                            nn.ConvTranspose1d(hidden,c_out,kernel_size,stride)
+            decoder_layer = [nn.Conv1d(hidden,hidden*multi_factor,encoder_decoder["kernel_size"],1),
+                            activation,
+                            nn.ConvTranspose1d(hidden,num_channels,encoder_decoder["kernel_size"],encoder_decoder["stride"])
                             ]
             if layer>0:
                 decoder_layer.append(nn.ReLU())
             decoder_layer = nn.Sequential(*decoder_layer)
             self.decoder.insert(0,decoder_layer)
 
-            c_out = hidden
-            c_in = hidden
-            hidden = self.growth_factor * hidden
+            num_channels = hidden
+            hidden = encoder_decoder["growth_factor"] * hidden
 
 
-        self.de_lstm = DeLSTM(input_size=c_in,hidden_size=c_in,num_layers=2,bidirectional=self.bidirectional)
+        self.de_lstm = DeLSTM(input_size=num_channels,hidden_size=num_channels,num_layers=lstm["num_layers"],bidirectional=lstm["bidirectional"])
 
     def forward(self,mixed_signal):
 
```
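With the grouped arguments, callers override only the keys they care about; everything else falls back to `ED_DEFAULTS` / `LSTM_DEFAULTS`. A hypothetical construction under that assumption:

```python
# Hypothetical usage of the refactored constructor: unspecified keys
# (e.g. kernel_size, growth_factor) keep their ED_DEFAULTS values.
model = Demucs(
    encoder_decoder={"depth": 4, "glu": False},
    lstm={"num_layers": 1},
    num_channels=1,
    sampling_rate=16000,
)
```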
```diff
@@ -91,14 +99,13 @@ class Demucs(nn.Module):
 
         length = mixed_signal.shape[-1]
         x = F.pad(mixed_signal, (0,self.get_padding_length(length) - length))
-        if self.resample>1:
-            x = audio.pt_resample_audio(audio=x, sr=self.sampling_rate,
-                        target_sr=int(self.sampling_rate * self.resample))
+        if self.hparams.resample>1:
+            x = audio.pt_resample_audio(audio=x, sr=self.hparams.sampling_rate,
+                        target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
 
         encoder_outputs = []
         for encoder in self.encoder:
             x = encoder(x)
-            print(x.shape)
             encoder_outputs.append(x)
         x = x.permute(0,2,1)
         x,_ = self.de_lstm(x)
```
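The switch from `self.resample` to `self.hparams.resample` works because `save_hyperparameters(...)` in `__init__` stores the named constructor arguments on `self.hparams`. This assumes the `Model` base class ultimately subclasses `pytorch_lightning.LightningModule`, where that API lives. A toy illustration of the pattern:

```python
import pytorch_lightning as pl

# Toy illustration of the hparams pattern, assuming Model subclasses LightningModule.
class Toy(pl.LightningModule):
    def __init__(self, resample: int = 4):
        super().__init__()
        self.save_hyperparameters("resample")  # exposes the arg as self.hparams.resample

assert Toy().hparams.resample == 4
```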
```diff
@@ -109,23 +116,23 @@ class Demucs(nn.Module):
             x += skip_connection[..., :x.shape[-1]]
             x = decoder(x)
 
-        if self.resample > 1:
-            x = audio.pt_resample_audio(x,int(self.sampling_rate * self.resample),
-                                    self.sampling_rate)
+        if self.hparams.resample > 1:
+            x = audio.pt_resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
+                                    self.hparams.sampling_rate)
 
         return x
 
     def get_padding_length(self,input_length):
 
-        input_length = math.ceil(input_length * self.resample)
+        input_length = math.ceil(input_length * self.hparams.resample)
 
 
-        for layer in range(self.depth):                                        # encoder operation
-            input_length = math.ceil((input_length - self.kernel_size)/self.stride)+1
+        for layer in range(self.hparams.encoder_decoder["depth"]):             # encoder operation
+            input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
             input_length = max(1,input_length)
-        for layer in range(self.depth):                                        # decoder operaration
-            input_length = (input_length-1) * self.stride + self.kernel_size
+        for layer in range(self.hparams.encoder_decoder["depth"]):             # decoder operation
+            input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
-        input_length = math.ceil(input_length/self.resample)
+        input_length = math.ceil(input_length/self.hparams.resample)
 
         return int(input_length)
```
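`get_padding_length` inverts the `Conv1d` length arithmetic: each encoder layer maps a length L to `ceil((L - kernel_size)/stride) + 1`, and each `ConvTranspose1d` maps it back to `(L - 1)*stride + kernel_size`. Running a length forward through `depth` encoder steps and back through `depth` decoder steps, at the resampled rate, gives the smallest length the network reproduces exactly, and `forward` right-pads the input to it. A standalone restatement with the new defaults, for a quick sanity check:

```python
import math

# Standalone restatement of get_padding_length using the new ED_DEFAULTS
# (kernel_size=8, stride=1, depth=5) and resample=4.
def padding_length(length, kernel_size=8, stride=1, depth=5, resample=4):
    length = math.ceil(length * resample)
    for _ in range(depth):  # encoder: Conv1d length arithmetic
        length = max(1, math.ceil((length - kernel_size) / stride) + 1)
    for _ in range(depth):  # decoder: ConvTranspose1d inverse
        length = (length - 1) * stride + kernel_size
    return int(math.ceil(length / resample))

print(padding_length(16000))            # 16000: stride 1 round-trips unchanged
print(padding_length(16000, stride=4))  # 16213: larger strides force padding up
```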