normalize input
This commit is contained in:
parent
e118c31f18
commit
415ed8e3d0
|
|
@ -133,10 +133,12 @@ class Demucs(Model):
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
resample: int = 4,
|
resample: int = 4,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
|
normalize=True,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
|
floor=1e-3,
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||||
|
|
@ -161,6 +163,8 @@ class Demucs(Model):
|
||||||
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
|
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
|
||||||
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
|
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
|
||||||
hidden = encoder_decoder["initial_output_channels"]
|
hidden = encoder_decoder["initial_output_channels"]
|
||||||
|
self.normalize = normalize
|
||||||
|
self.floor = floor
|
||||||
self.encoder = nn.ModuleList()
|
self.encoder = nn.ModuleList()
|
||||||
self.decoder = nn.ModuleList()
|
self.decoder = nn.ModuleList()
|
||||||
|
|
||||||
|
|
@ -204,7 +208,10 @@ class Demucs(Model):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
||||||
)
|
)
|
||||||
|
if self.normalize:
|
||||||
|
waveform = waveform.mean(dim=1, keepdim=True)
|
||||||
|
std = waveform.std(dim=-1, keepdim=True)
|
||||||
|
waveform = waveform / (self.floor + std)
|
||||||
length = waveform.shape[-1]
|
length = waveform.shape[-1]
|
||||||
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
|
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
|
||||||
if self.hparams.resample > 1:
|
if self.hparams.resample > 1:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue