pass total frames

This commit is contained in:
shahules786 2022-09-26 17:09:11 +05:30
parent 9890249824
commit e9ea0d1695
1 changed files with 1 additions and 1 deletions

View File

@ -199,7 +199,7 @@ class Model(pl.LightningModule):
window_size = round(duration * model_sampling_rate)
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
waveform = Inference.aggreagate(batch_prediction,window_size,step_size)
waveform = Inference.aggreagate(batch_prediction,window_size,waveform.shape[-1],step_size,)
if save_output and isinstance(audio,(str,Path)):
Inference.write_output(waveform,audio,model_sampling_rate)