normalize input

This commit is contained in:
shahules786 2022-10-18 15:22:34 +05:30
parent e118c31f18
commit 415ed8e3d0
1 changed files with 8 additions and 1 deletions

View File

@ -133,10 +133,12 @@ class Demucs(Model):
num_channels: int = 1, num_channels: int = 1,
resample: int = 4, resample: int = 4,
sampling_rate=16000, sampling_rate=16000,
normalize=True,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None, dataset: Optional[EnhancerDataset] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
floor=1e-3,
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None dataset.duration if isinstance(dataset, EnhancerDataset) else None
@ -161,6 +163,8 @@ class Demucs(Model):
lstm = merge_dict(self.LSTM_DEFAULTS, lstm) lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
self.save_hyperparameters("encoder_decoder", "lstm", "resample") self.save_hyperparameters("encoder_decoder", "lstm", "resample")
hidden = encoder_decoder["initial_output_channels"] hidden = encoder_decoder["initial_output_channels"]
self.normalize = normalize
self.floor = floor
self.encoder = nn.ModuleList() self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList() self.decoder = nn.ModuleList()
@ -204,7 +208,10 @@ class Demucs(Model):
raise TypeError( raise TypeError(
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
) )
if self.normalize:
waveform = waveform.mean(dim=1, keepdim=True)
std = waveform.std(dim=-1, keepdim=True)
waveform = waveform / (self.floor + std)
length = waveform.shape[-1] length = waveform.shape[-1]
x = F.pad(waveform, (0, self.get_padding_length(length) - length)) x = F.pad(waveform, (0, self.get_padding_length(length) - length))
if self.hparams.resample > 1: if self.hparams.resample > 1: