diff --git a/enhancer/inference.py b/enhancer/inference.py index fc60624..414f6ae 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -2,6 +2,7 @@ from json import load import wave import numpy as np from scipy.signal import get_window +from scipy.io import wavfile from typing import List, Optional, Union import torch import torch.nn.functional as F @@ -73,6 +74,19 @@ class Inference: return data + @staticmethod + def write_output(waveform:torch.Tensor,filename:Union[str,Path],sr:int): + + if isinstance(filename,str): + filename = Path(filename) + if filename.is_file(): + raise FileExistsError(f"file {filename} already exists") + else: + wavfile.write(filename,rate=sr,data=waveform.detach().cpu()) + + + +