rename as infer
This commit is contained in:
parent
04656487ab
commit
868fde7e69
|
|
@ -169,7 +169,7 @@ class Model(pl.LightningModule):
|
|||
|
||||
return model
|
||||
|
||||
def infer_batch(self,batch,batch_size):
|
||||
def infer(self,batch:torch.Tensor,batch_size:int=32):
|
||||
|
||||
assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||
batch_predictions = []
|
||||
|
|
@ -198,7 +198,7 @@ class Model(pl.LightningModule):
|
|||
waveform.to(self.device)
|
||||
window_size = round(duration * model_sampling_rate)
|
||||
batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size)
|
||||
batch_prediction = self.infer_batch(batched_waveform,batch_size=batch_size)
|
||||
batch_prediction = self.infer(batched_waveform,batch_size=batch_size)
|
||||
waveform = Inference.aggreagate(batch_prediction,window_size,step_size)
|
||||
|
||||
if save_output and isinstance(audio,(str,Path)):
|
||||
|
|
|
|||
Loading…
Reference in New Issue