fix duration

This commit is contained in:
shahules786 2022-11-15 21:42:02 +05:30
parent 003bab91f9
commit 2bfca78caa
4 changed files with 11 additions and 10 deletions

View File

@ -1,4 +1,4 @@
import logging
import warnings
from typing import Any, List, Optional, Tuple, Union
import torch
@ -140,11 +140,11 @@ class DCCRN(Mayamodel):
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

View File

@ -1,5 +1,5 @@
import logging
import math
import warnings
from typing import List, Optional, Union
import torch.nn.functional as F
@ -136,16 +136,17 @@ class Demucs(Mayamodel):
normalize=True,
lr: float = 1e-3,
dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
floor=1e-3,
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

View File

@ -24,7 +24,7 @@ CACHE_DIR = os.getenv(
)
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu"
SAVE_NAME = "enhancer"
SAVE_NAME = "mayavoz"
class Mayamodel(pl.LightningModule):

View File

@ -1,4 +1,4 @@
import logging
import warnings
from typing import List, Optional, Union
import torch
@ -103,11 +103,11 @@ class WaveUnet(Mayamodel):
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate