rename as infer

This commit is contained in:
shahules786 2022-09-26 12:38:54 +05:30
parent 04656487ab
commit 868fde7e69
1 changed files with 2 additions and 2 deletions

View File

@ -169,7 +169,7 @@ class Model(pl.LightningModule):
return model return model
def infer_batch(self,batch,batch_size): def infer(self,batch:torch.Tensor,batch_size:int=32):
assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
batch_predictions = [] batch_predictions = []
@ -198,7 +198,7 @@ class Model(pl.LightningModule):
waveform.to(self.device) waveform.to(self.device)
window_size = round(duration * model_sampling_rate) window_size = round(duration * model_sampling_rate)
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
batch_prediction = self.infer_batch(batched_waveform,batch_size=batch_size) batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
waveform = Inference.aggreagate(batch_prediction,window_size,step_size) waveform = Inference.aggreagate(batch_prediction,window_size,step_size)
if save_output and isinstance(audio,(str,Path)): if save_output and isinstance(audio,(str,Path)):