diff --git a/enhancer/inference.py b/enhancer/inference.py index 2c63d54..fd2f57a 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -84,10 +84,15 @@ class Inference: if isinstance(filename,str): filename = Path(filename) + + parent, name = filename.parent, "cleaned_"+filename.name + filename = parent/Path(name) if filename.is_file(): raise FileExistsError(f"file {filename} already exists") else: - wavfile.write(filename,rate=sr,data=waveform.detach().cpu()) + if isinstance(waveform,torch.Tensor): + waveform = waveform.detach().cpu().squeeze().numpy() + wavfile.write(filename,rate=sr,data=waveform) @staticmethod def prepare_output(waveform:torch.Tensor, model_sampling_rate:int,