diff --git a/enhancer/models/model.py b/enhancer/models/model.py index c679669..9f285d3 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -18,8 +18,11 @@ from enhancer.inference import Inference from enhancer.loss import LOSS_MAP, LossWrapper from enhancer.version import __version__ -CACHE_DIR = "" -HF_TORCH_WEIGHTS = "" +CACHE_DIR = os.getenv( + "ENHANCER_CACHE", + os.path.expanduser("~/.cache/torch/enhancer"), +) +HF_TORCH_WEIGHTS = "pytorch_model.ckpt" DEFAULT_DEVICE = "cpu"