diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 999dc6c..83355fd 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -169,7 +169,7 @@ class Model(pl.LightningModule): return model - def infer_batch(self,batch,batch_size): + def infer(self,batch:torch.Tensor,batch_size:int=32): assert batch.ndim == 3, f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" batch_predictions = [] @@ -198,7 +198,7 @@ class Model(pl.LightningModule): waveform.to(self.device) window_size = round(duration * model_sampling_rate) batched_waveform = Inference.batchify(waveform,window_size,step_size=step_size) - batch_prediction = self.infer_batch(batched_waveform,batch_size=batch_size) + batch_prediction = self.infer(batched_waveform,batch_size=batch_size) waveform = Inference.aggreagate(batch_prediction,window_size,step_size) if save_output and isinstance(audio,(str,Path)):