From 415ed8e3d0733bccf1d94b25100dfb03ba4fc4ad Mon Sep 17 00:00:00 2001 From: shahules786 Date: Tue, 18 Oct 2022 15:22:34 +0530 Subject: [PATCH] normalize input --- enhancer/models/demucs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 95d6a6f..86afb6c 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -133,10 +133,12 @@ class Demucs(Model): num_channels: int = 1, resample: int = 4, sampling_rate=16000, + normalize=True, lr: float = 1e-3, dataset: Optional[EnhancerDataset] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", + floor=1e-3, ): duration = ( dataset.duration if isinstance(dataset, EnhancerDataset) else None @@ -161,6 +163,8 @@ class Demucs(Model): lstm = merge_dict(self.LSTM_DEFAULTS, lstm) self.save_hyperparameters("encoder_decoder", "lstm", "resample") hidden = encoder_decoder["initial_output_channels"] + self.normalize = normalize + self.floor = floor self.encoder = nn.ModuleList() self.decoder = nn.ModuleList() @@ -204,7 +208,10 @@ class Demucs(Model): raise TypeError( f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels" ) - + if self.normalize: + waveform = waveform.mean(dim=1, keepdim=True) + std = waveform.std(dim=-1, keepdim=True) + waveform = waveform / (self.floor + std) length = waveform.shape[-1] x = F.pad(waveform, (0, self.get_padding_length(length) - length)) if self.hparams.resample > 1: