diff --git a/mayavoz/models/dccrn.py b/mayavoz/models/dccrn.py index 6b8646c..638aefe 100644 --- a/mayavoz/models/dccrn.py +++ b/mayavoz/models/dccrn.py @@ -1,4 +1,4 @@ -import logging +import warnings from typing import Any, List, Optional, Tuple, Union import torch @@ -140,11 +140,11 @@ class DCCRN(Mayamodel): metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/mayavoz/models/demucs.py b/mayavoz/models/demucs.py index 8424f17..dbe584b 100644 --- a/mayavoz/models/demucs.py +++ b/mayavoz/models/demucs.py @@ -1,5 +1,5 @@ -import logging import math +import warnings from typing import List, Optional, Union import torch.nn.functional as F @@ -136,16 +136,17 @@ class Demucs(Mayamodel): normalize=True, lr: float = 1e-3, dataset: Optional[MayaDataset] = None, + duration: Optional[float] = None, loss: Union[str, List] = "mse", metric: Union[str, List] = "mse", floor=1e-3, ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/mayavoz/models/model.py b/mayavoz/models/model.py index e248b2c..5143d0b 100644 --- a/mayavoz/models/model.py +++ b/mayavoz/models/model.py @@ -24,7 +24,7 @@ CACHE_DIR = os.getenv( ) HF_TORCH_WEIGHTS = "pytorch_model.ckpt" DEFAULT_DEVICE = "cpu" -SAVE_NAME = "enhancer" +SAVE_NAME = "mayavoz" class Mayamodel(pl.LightningModule): diff --git a/mayavoz/models/waveunet.py b/mayavoz/models/waveunet.py index c9acfda..0e2ec80 100644 --- a/mayavoz/models/waveunet.py +++ b/mayavoz/models/waveunet.py @@ -1,4 +1,4 @@ -import logging +import warnings from typing import List, Optional, Union import torch @@ -103,11 +103,11 @@ class WaveUnet(Mayamodel): metric: Union[str, List] = "mse", ): duration = ( - dataset.duration if isinstance(dataset, MayaDataset) else None + dataset.duration if isinstance(dataset, MayaDataset) else duration ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warning( + warnings.warn( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate