normalize input
This commit is contained in:
parent
e118c31f18
commit
415ed8e3d0
|
|
@ -133,10 +133,12 @@ class Demucs(Model):
|
|||
num_channels: int = 1,
|
||||
resample: int = 4,
|
||||
sampling_rate=16000,
|
||||
normalize=True,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[EnhancerDataset] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
floor=1e-3,
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||
|
|
@ -161,6 +163,8 @@ class Demucs(Model):
|
|||
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
|
||||
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
|
||||
hidden = encoder_decoder["initial_output_channels"]
|
||||
self.normalize = normalize
|
||||
self.floor = floor
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
|
|
@ -204,7 +208,10 @@ class Demucs(Model):
|
|||
raise TypeError(
|
||||
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
||||
)
|
||||
|
||||
if self.normalize:
|
||||
waveform = waveform.mean(dim=1, keepdim=True)
|
||||
std = waveform.std(dim=-1, keepdim=True)
|
||||
waveform = waveform / (self.floor + std)
|
||||
length = waveform.shape[-1]
|
||||
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
|
||||
if self.hparams.resample > 1:
|
||||
|
|
|
|||
Loading…
Reference in New Issue