fix duration
This commit is contained in:
parent
003bab91f9
commit
2bfca78caa
|
|
@ -1,4 +1,4 @@
|
||||||
import logging
|
import warnings
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -140,11 +140,11 @@ class DCCRN(Mayamodel):
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warning(
|
warnings.warn(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -136,16 +136,17 @@ class Demucs(Mayamodel):
|
||||||
normalize=True,
|
normalize=True,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[MayaDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
floor=1e-3,
|
floor=1e-3,
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warning(
|
warnings.warn(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ CACHE_DIR = os.getenv(
|
||||||
)
|
)
|
||||||
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
||||||
DEFAULT_DEVICE = "cpu"
|
DEFAULT_DEVICE = "cpu"
|
||||||
SAVE_NAME = "enhancer"
|
SAVE_NAME = "mayavoz"
|
||||||
|
|
||||||
|
|
||||||
class Mayamodel(pl.LightningModule):
|
class Mayamodel(pl.LightningModule):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import logging
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -103,11 +103,11 @@ class WaveUnet(Mayamodel):
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warning(
|
warnings.warn(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue