diff --git a/enhancer/inference.py b/enhancer/inference.py index 2c63d54..fd2f57a 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -84,10 +84,15 @@ class Inference: if isinstance(filename,str): filename = Path(filename) + + parent, name = filename.parent, "cleaned_"+filename.name + filename = parent/Path(name) if filename.is_file(): raise FileExistsError(f"file {filename} already exists") else: - wavfile.write(filename,rate=sr,data=waveform.detach().cpu()) + if isinstance(waveform,torch.Tensor): + waveform = waveform.detach().cpu().squeeze().numpy() + wavfile.write(filename,rate=sr,data=waveform) @staticmethod def prepare_output(waveform:torch.Tensor, model_sampling_rate:int, diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 5827301..de2edab 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -249,7 +249,7 @@ class Model(pl.LightningModule): else: waveform = Inference.prepare_output(waveform, model_sampling_rate, audio, sampling_rate) - return waveform + return waveform @property def valid_monitor(self):