pass total frames
This commit is contained in:
parent
9890249824
commit
e9ea0d1695
|
|
@ -199,7 +199,7 @@ class Model(pl.LightningModule):
|
|||
window_size = round(duration * model_sampling_rate)
|
||||
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
|
||||
batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
|
||||
waveform = Inference.aggreagate(batch_prediction,window_size,step_size)
|
||||
waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
|
||||
|
||||
if save_output and isinstance(audio,(str,Path)):
|
||||
Inference.write_output(waveform,audio,model_sampling_rate)
|
||||
|
|
|
|||
Loading…
Reference in New Issue