normalize input

This commit is contained in:
shahules786 2022-10-18 15:22:34 +05:30
parent e118c31f18
commit 415ed8e3d0
1 changed files with 8 additions and 1 deletions

View File

@ -133,10 +133,12 @@ class Demucs(Model):
num_channels: int = 1,
resample: int = 4,
sampling_rate=16000,
normalize=True,
lr: float = 1e-3,
dataset: Optional[EnhancerDataset] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
floor=1e-3,
):
duration = (
dataset.duration if isinstance(dataset, EnhancerDataset) else None
@ -161,6 +163,8 @@ class Demucs(Model):
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
hidden = encoder_decoder["initial_output_channels"]
self.normalize = normalize
self.floor = floor
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
@ -204,7 +208,10 @@ class Demucs(Model):
raise TypeError(
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
)
if self.normalize:
waveform = waveform.mean(dim=1, keepdim=True)
std = waveform.std(dim=-1, keepdim=True)
waveform = waveform / (self.floor + std)
length = waveform.shape[-1]
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
if self.hparams.resample > 1: