diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 92d30ae..c679669 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -37,7 +37,7 @@ class Model(pl.LightningModule): Enhancer dataset used for training/validation duration: float, optional duration used for training/inference - loss : string or List of strings, default to "mse" + loss : string or List of strings or custom loss (nn.Module), default to "mse" loss functions to be used. Available ("mse","mae","Si-SDR") """