commit 9936bbc3e9
Author: Shahul ES
@@ -9,17 +9,19 @@ from enhancer.data.dataset import EnhancerDataset
 from enhancer.utils.io import Audio as audio
 from enhancer.utils.utils import merge_dict
 
+
 class DemucsLSTM(nn.Module):
     def __init__(
         self,
         input_size: int,
         hidden_size: int,
         num_layers: int,
-        bidirectional:bool=True
-
+        bidirectional: bool = True,
     ):
         super().__init__()
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
+        self.lstm = nn.LSTM(
+            input_size, hidden_size, num_layers, bidirectional=bidirectional
+        )
         dim = 2 if bidirectional else 1
         self.linear = nn.Linear(dim * hidden_size, hidden_size)
 
@@ -32,7 +34,6 @@ class DemucsLSTM(nn.Module):
 
 
 class DemucsEncoder(nn.Module):
-
     def __init__(
         self,
         num_channels: int,
@@ -48,15 +49,15 @@ class DemucsEncoder(nn.Module):
             nn.Conv1d(num_channels, hidden_size, kernel_size, stride),
             nn.ReLU(),
             nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
-            activation
+            activation,
         )
 
     def forward(self, waveform):
 
         return self.encoder(waveform)
 
-class DemucsDecoder(nn.Module):
 
+class DemucsDecoder(nn.Module):
     def __init__(
         self,
         num_channels: int,
@@ -64,7 +65,7 @@ class DemucsDecoder(nn.Module):
         kernel_size: int,
         stride: int = 1,
         glu: bool = False,
-        layer:int=0
+        layer: int = 0,
     ):
         super().__init__()
         activation = nn.GLU(1) if glu else nn.ReLU()
@@ -72,18 +73,44 @@ class DemucsDecoder(nn.Module):
         self.decoder = nn.Sequential(
             nn.Conv1d(hidden_size, hidden_size * multi_factor, kernel_size, 1),
             activation,
-            nn.ConvTranspose1d(hidden_size,num_channels,kernel_size,stride)
+            nn.ConvTranspose1d(hidden_size, num_channels, kernel_size, stride),
         )
         if layer > 0:
             self.decoder.add_module("4", nn.ReLU())
 
-    def forward(self,waveform,):
+    def forward(
+        self,
+        waveform,
+    ):
 
         out = self.decoder(waveform)
         return out
 
 
 class Demucs(Model):
+    """
+    Demucs model from https://arxiv.org/pdf/1911.13254.pdf
+    parameters:
+        encoder_decoder: dict, optional
+            keyword arguments passed to encoder decoder block
+        lstm : dict, optional
+            keyword arguments passed to LSTM block
+        num_channels: int, defaults to 1
+            number of channels in input audio
+        sampling_rate: int, defaults to 16KHz
+            sampling rate of input audio
+        lr : float, defaults to 1e-3
+            learning rate used for training
+        dataset: EnhancerDataset, optional
+            EnhancerDataset object containing train/validation data for training
+        duration : float, optional
+            chunk duration in seconds
+        loss : string or List of strings
+            loss function to be used, available ("mse","mae","SI-SDR")
+        metric : string or List of strings
+            metric function to be used, available ("mse","mae","SI-SDR")
+
+    """
 
     ED_DEFAULTS = {
         "initial_output_channels": 48,
@@ -108,18 +135,26 @@ class Demucs(Model):
         lr: float = 1e-3,
         dataset: Optional[EnhancerDataset] = None,
         loss: Union[str, List] = "mse",
-        metric:Union[str, List] = "mse"
-
-
+        metric: Union[str, List] = "mse",
     ):
-        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        duration = (
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
+        )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+                logging.warn(
+                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+                )
                 sampling_rate = dataset.sampling_rate
-        super().__init__(num_channels=num_channels,
-                            sampling_rate=sampling_rate,lr=lr,
-                            dataset=dataset,duration=duration,loss=loss, metric=metric)
+        super().__init__(
+            num_channels=num_channels,
+            sampling_rate=sampling_rate,
+            lr=lr,
+            dataset=dataset,
+            duration=duration,
+            loss=loss,
+            metric=metric,
+        )
 
         encoder_decoder = merge_dict(self.ED_DEFAULTS, encoder_decoder)
         lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
@@ -130,7 +165,8 @@ class Demucs(Model):
 
         for layer in range(encoder_decoder["depth"]):
 
-            encoder_layer = DemucsEncoder(num_channels=num_channels,
+            encoder_layer = DemucsEncoder(
+                num_channels=num_channels,
                 hidden_size=hidden,
                 kernel_size=encoder_decoder["kernel_size"],
                 stride=encoder_decoder["stride"],
@@ -138,22 +174,24 @@ class Demucs(Model):
             )
             self.encoder.append(encoder_layer)
 
-            decoder_layer = DemucsDecoder(num_channels=num_channels,
+            decoder_layer = DemucsDecoder(
+                num_channels=num_channels,
                 hidden_size=hidden,
                 kernel_size=encoder_decoder["kernel_size"],
                 stride=1,
                 glu=encoder_decoder["glu"],
-                            layer=layer
+                layer=layer,
             )
             self.decoder.insert(0, decoder_layer)
 
             num_channels = hidden
             hidden = self.ED_DEFAULTS["growth_factor"] * hidden
 
-        self.de_lstm = DemucsLSTM(input_size=num_channels,
+        self.de_lstm = DemucsLSTM(
+            input_size=num_channels,
             hidden_size=num_channels,
             num_layers=lstm["num_layers"],
-                        bidirectional=lstm["bidirectional"]
+            bidirectional=lstm["bidirectional"],
         )
 
     def forward(self, waveform):
@@ -162,13 +200,20 @@ class Demucs(Model):
             waveform = waveform.unsqueeze(1)
 
         if waveform.size(1) != 1:
-            raise TypeError(f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels")
+            raise TypeError(
+                f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
+            )
 
         length = waveform.shape[-1]
         x = F.pad(waveform, (0, self.get_padding_length(length) - length))
         if self.hparams.resample > 1:
-            x = audio.resample_audio(audio=x, sr=self.hparams.sampling_rate,
-                        target_sr=int(self.hparams.sampling_rate * self.hparams.resample))
+            x = audio.resample_audio(
+                audio=x,
+                sr=self.hparams.sampling_rate,
+                target_sr=int(
+                    self.hparams.sampling_rate * self.hparams.resample
+                ),
+            )
 
         encoder_outputs = []
         for encoder in self.encoder:
@@ -184,8 +229,11 @@ class Demucs(Model):
             x = decoder(x)
 
         if self.hparams.resample > 1:
-            x = audio.resample_audio(x,int(self.hparams.sampling_rate * self.hparams.resample),
-                                    self.hparams.sampling_rate)
+            x = audio.resample_audio(
+                x,
+                int(self.hparams.sampling_rate * self.hparams.resample),
+                self.hparams.sampling_rate,
+            )
 
         return x
 
@@ -193,25 +241,23 @@ class Demucs(Model):
 
         input_length = math.ceil(input_length * self.hparams.resample)
 
-
-        for layer in range(self.hparams.encoder_decoder["depth"]):                    # encoder operation
-            input_length = math.ceil((input_length - self.hparams.encoder_decoder["kernel_size"])/self.hparams.encoder_decoder["stride"])+1
+        for layer in range(
+            self.hparams.encoder_decoder["depth"]
+        ):  # encoder operation
+            input_length = (
+                math.ceil(
+                    (input_length - self.hparams.encoder_decoder["kernel_size"])
+                    / self.hparams.encoder_decoder["stride"]
+                )
+                + 1
+            )
             input_length = max(1, input_length)
-        for layer in range(self.hparams.encoder_decoder["depth"]):                    # decoder operaration
-            input_length = (input_length-1) * self.hparams.encoder_decoder["stride"] + self.hparams.encoder_decoder["kernel_size"]
+        for layer in range(
+            self.hparams.encoder_decoder["depth"]
+        ):  # decoder operation
+            input_length = (input_length - 1) * self.hparams.encoder_decoder[
+                "stride"
+            ] + self.hparams.encoder_decoder["kernel_size"]
         input_length = math.ceil(input_length / self.hparams.resample)
 
         return int(input_length)
-
-
-
-
-
-
-
-
-
-
-
-
-
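Not part of the commit: a minimal usage sketch for the reformatted Demucs class above, assuming the module lives at enhancer.models.demucs and the defaults documented in the new docstring (mono input, 16 kHz). It only illustrates how the constructor, get_padding_length, and the forward pass fit together; it is not a definitive API.

    import torch

    from enhancer.models.demucs import Demucs  # import path assumed from the repo layout

    model = Demucs(num_channels=1, sampling_rate=16000)  # ED_DEFAULTS / LSTM_DEFAULTS apply
    model.eval()

    # one second of fake mono audio: (batch, channel, samples); multi-channel raises TypeError
    noisy = torch.randn(1, 1, 16000)
    with torch.no_grad():
        enhanced = model(noisy)

    # length the waveform is zero-padded to before encoding, per get_padding_length above
    print(model.get_padding_length(16000))
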
@@ -7,8 +7,8 @@ from typing import Optional, Union, List
 from enhancer.models.model import Model
 from enhancer.data.dataset import EnhancerDataset
 
-class WavenetDecoder(nn.Module):
 
+class WavenetDecoder(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -20,17 +20,24 @@ class WavenetDecoder(nn.Module):
     ):
         super(WavenetDecoder, self).__init__()
         self.decoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            ),
             nn.BatchNorm1d(out_channels),
-            nn.LeakyReLU(negative_slope=0.1)
+            nn.LeakyReLU(negative_slope=0.1),
         )
 
     def forward(self, waveform):
 
         return self.decoder(waveform)
 
-class WavenetEncoder(nn.Module):
 
+class WavenetEncoder(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -42,20 +49,45 @@ class WavenetEncoder(nn.Module):
     ):
         super(WavenetEncoder, self).__init__()
         self.encoder = nn.Sequential(
-            nn.Conv1d(in_channels,out_channels,kernel_size,stride=stride,padding=padding,dilation=dilation),
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride=stride,
+                padding=padding,
+                dilation=dilation,
+            ),
             nn.BatchNorm1d(out_channels),
-            nn.LeakyReLU(negative_slope=0.1)
+            nn.LeakyReLU(negative_slope=0.1),
         )
 
-
-    def forward(
-        self,
-        waveform
-    ):
+    def forward(self, waveform):
         return self.encoder(waveform)
 
 
 class WaveUnet(Model):
+    """
+    Wave-U-Net model from https://arxiv.org/pdf/1811.11307.pdf
+    parameters:
+        num_channels: int, defaults to 1
+            number of channels in input audio
+        depth : int, defaults to 12
+            depth of network
+        initial_output_channels: int, defaults to 24
+            number of output channels in initial upsampling layer
+        sampling_rate: int, defaults to 16KHz
+            sampling rate of input audio
+        lr : float, defaults to 1e-3
+            learning rate used for training
+        dataset: EnhancerDataset, optional
+            EnhancerDataset object containing train/validation data for training
+        duration : float, optional
+            chunk duration in seconds
+        loss : string or List of strings
+            loss function to be used, available ("mse","mae","SI-SDR")
+        metric : string or List of strings
+            metric function to be used, available ("mse","mae","SI-SDR")
+    """
 
     def __init__(
         self,
@@ -67,16 +99,25 @@ class WaveUnet(Model):
         dataset: Optional[EnhancerDataset] = None,
         duration: Optional[float] = None,
         loss: Union[str, List] = "mse",
-        metric:Union[str,List] = "mse"
+        metric: Union[str, List] = "mse",
     ):
-        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        duration = (
+            dataset.duration if isinstance(dataset, EnhancerDataset) else None
+        )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+                logging.warn(
+                    f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
+                )
                 sampling_rate = dataset.sampling_rate
-        super().__init__(num_channels=num_channels,
-                            sampling_rate=sampling_rate,lr=lr,
-                            dataset=dataset,duration=duration,loss=loss, metric=metric
+        super().__init__(
+            num_channels=num_channels,
+            sampling_rate=sampling_rate,
+            lr=lr,
+            dataset=dataset,
+            duration=duration,
+            loss=loss,
+            metric=metric,
         )
         self.save_hyperparameters("depth")
         self.encoders = nn.ModuleList()
@@ -90,33 +131,35 @@ class WaveUnet(Model):
             num_channels = out_channels
             out_channels += initial_output_channels
             if layer == depth - 1:
-                decoder = WavenetDecoder(depth * initial_output_channels + num_channels,num_channels)
+                decoder = WavenetDecoder(
+                    depth * initial_output_channels + num_channels, num_channels
+                )
             else:
-                decoder = WavenetDecoder(num_channels+out_channels,num_channels)
+                decoder = WavenetDecoder(
+                    num_channels + out_channels, num_channels
+                )
 
             self.decoders.insert(0, decoder)
 
         bottleneck_dim = depth * initial_output_channels
         self.bottleneck = nn.Sequential(
-            nn.Conv1d(bottleneck_dim,bottleneck_dim, 15, stride=1,
-                      padding=7),
+            nn.Conv1d(bottleneck_dim, bottleneck_dim, 15, stride=1, padding=7),
             nn.BatchNorm1d(bottleneck_dim),
-            nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            nn.LeakyReLU(negative_slope=0.1, inplace=True),
         )
         self.final = nn.Sequential(
             nn.Conv1d(1 + initial_output_channels, 1, kernel_size=1, stride=1),
-            nn.Tanh()
+            nn.Tanh(),
        )
 
-
-    def forward(
-        self,waveform
-    ):
+    def forward(self, waveform):
         if waveform.dim() == 2:
             waveform = waveform.unsqueeze(1)
 
         if waveform.size(1) != 1:
-            raise TypeError(f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels")
+            raise TypeError(
+                f"Wave-U-Net can only process mono channel audio, input has {waveform.size(1)} channels"
+            )
 
         encoder_outputs = []
         out = waveform
@@ -139,10 +182,12 @@ class WaveUnet(Model):
 
     def fix_last_dim(self, x, target):
         """
-        trying to do centre crop along last dimension
+        centre crop along last dimension
        """
 
-        assert x.shape[-1] >= target.shape[-1], "input dimension cannot be larger than target dimension"
+        assert (
+            x.shape[-1] >= target.shape[-1]
+        ), "input dimension cannot be larger than target dimension"
         if x.shape[-1] == target.shape[-1]:
             return x
 
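Not part of the commit: a comparable sketch for the WaveUnet class, again assuming the module path (enhancer.models.waveunet) and the documented defaults (depth 12, 24 initial output channels, 16 kHz). The input length is chosen as a power of two, a conservative choice so the depth-12 U-Net's intermediate sizes line up cleanly; see fix_last_dim above for how mismatched skip-connection lengths are handled.

    import torch

    from enhancer.models.waveunet import WaveUnet  # import path assumed from the repo layout

    model = WaveUnet(num_channels=1, sampling_rate=16000)  # depth=12, initial_output_channels=24
    model.eval()

    # (batch, channel, samples); mono only, multi-channel input raises TypeError
    noisy = torch.randn(1, 1, 16384)
    with torch.no_grad():
        enhanced = model(noisy)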