From 868fde7e691ef70241dd0afec38192effc7578e0 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 26 Sep 2022 12:38:54 +0530 Subject: [PATCH] rename as infer --- enhancer/models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 999dc6c..83355fd 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -169,7 +169,7 @@ class Model(pl.LightningModule): return model - def infer_batch(self,batch,batch_size): + def infer(self,batch:torch.Tensor,batch_size:int=32): assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" batch_predictions = [] @@ -198,7 +198,7 @@ class Model(pl.LightningModule): waveform.to(self.device) window_size = round(duration * model_sampling_rate) batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) - batch_prediction = self.infer_batch(batched_waveform,batch_size=batch_size) + batch_prediction = self.infer(batched_waveform,batch_size=batch_size) waveform = Inference.aggreagate(batch_prediction,window_size,step_size) if save_output and isinstance(audio,(str,Path)):