fix duration
This commit is contained in:
parent
003bab91f9
commit
2bfca78caa
|
|
@ -1,4 +1,4 @@
|
|||
import logging
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -140,11 +140,11 @@ class DCCRN(Mayamodel):
|
|||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
logging.warning(
|
||||
warnings.warn(
|
||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||
)
|
||||
sampling_rate = dataset.sampling_rate
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -136,16 +136,17 @@ class Demucs(Mayamodel):
|
|||
normalize=True,
|
||||
lr: float = 1e-3,
|
||||
dataset: Optional[MayaDataset] = None,
|
||||
duration: Optional[float] = None,
|
||||
loss: Union[str, List] = "mse",
|
||||
metric: Union[str, List] = "mse",
|
||||
floor=1e-3,
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
logging.warning(
|
||||
warnings.warn(
|
||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||
)
|
||||
sampling_rate = dataset.sampling_rate
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ CACHE_DIR = os.getenv(
|
|||
)
|
||||
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
||||
DEFAULT_DEVICE = "cpu"
|
||||
SAVE_NAME = "enhancer"
|
||||
SAVE_NAME = "mayavoz"
|
||||
|
||||
|
||||
class Mayamodel(pl.LightningModule):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import logging
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -103,11 +103,11 @@ class WaveUnet(Mayamodel):
|
|||
metric: Union[str, List] = "mse",
|
||||
):
|
||||
duration = (
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
||||
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||
)
|
||||
if dataset is not None:
|
||||
if sampling_rate != dataset.sampling_rate:
|
||||
logging.warning(
|
||||
warnings.warn(
|
||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||
)
|
||||
sampling_rate = dataset.sampling_rate
|
||||
|
|
|
|||
Loading…
Reference in New Issue