Merge pull request #24 from shahules786/dev
Minor improvements/bug fixes
This commit is contained in:
commit da85de13ad

@@ -1,4 +1,8 @@
#local
cleaned_my_voice.wav
lightning_logs/
my_voice.wav
pretrained/
*.ckpt
*_local.yaml
cli/train_config/dataset/Vctk_local.yaml

@@ -1 +1,2 @@
__import__("pkg_resources").declare_namespace(__name__)
from mayavoz.models import Mayamodel
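
The added line re-exports the base model from the package namespace. Assuming this hunk edits mayavoz's top-level __init__.py (the file name is not shown in the extract), a hedged sketch of what the re-export enables:

    # hedged usage sketch; assumes the hunk above belongs to mayavoz/__init__.py
    from mayavoz import Mayamodel  # reachable from the package root, not only mayavoz.models

    assert Mayamodel is not None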

@@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
def train(config: DictConfig):

OmegaConf.save(config, "config_log.yaml")
OmegaConf.save(config, "config.yaml")

callbacks = []
logger = MLFlowLogger(

@@ -96,7 +96,7 @@ def main(config: DictConfig):
trainer.test(model)

logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
)

saved_location = os.path.join(

@@ -117,4 +117,4 @@ def main(config: DictConfig):

if __name__ == "__main__":
main()
train()
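
Taken together, the three hunks above rename the Hydra entrypoint from main to train and save the resolved config as config.yaml instead of config_log.yaml, matching the artifact path logged to MLflow. A minimal sketch of the resulting entrypoint shape, with the training body elided:

    import hydra
    from omegaconf import DictConfig, OmegaConf

    @hydra.main(config_path="train_config", config_name="config")
    def train(config: DictConfig):
        OmegaConf.save(config, "config.yaml")  # later logged as an MLflow artifact
        ...  # build callbacks, logger, trainer, then fit/test

    if __name__ == "__main__":
        train()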

@@ -1,10 +1,10 @@
_target_: mayavoz.data.dataset.MayaDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 2.0
sampling_rate: 16000
batch_size: 32
valid_size: 0.05
min_valid_minutes: 15
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training
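
This dataset config follows Hydra's `_target_` convention, so it can be turned into a MayaDataset directly. A hedged sketch of that instantiation; the YAML filename below is hypothetical (the file presumably lives under cli/train_config/dataset/, per the .gitignore entry above):

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("cli/train_config/dataset/MS-SNSD.yaml")  # hypothetical filename
    dataset = instantiate(cfg)  # resolves _target_: mayavoz.data.dataset.MayaDataset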

@@ -1,6 +1,8 @@
import math
import multiprocessing
import os
import sys
import warnings
from pathlib import Path
from typing import Optional

@@ -80,6 +82,21 @@ class TaskDataset(pl.LightningDataModule):
self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2

if (
num_workers > 0
and sys.platform == "darwin"
and sys.version_info[0] >= 3
and sys.version_info[1] >= 8
):
warnings.warn(
"num_workers > 0 is not supported with macOS and Python 3.8+: "
"setting num_workers = 0."
)
num_workers = 0

self.num_workers = num_workers
if min_valid_minutes > 0.0:
self.min_valid_minutes = min_valid_minutes
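
The new block defaults num_workers to half the CPU count and then forces it to 0 on macOS with Python 3.8+, where the default multiprocessing start method changed to "spawn" and DataLoader workers become problematic. A standalone sketch of the same guard; the helper name is hypothetical:

    import multiprocessing
    import sys
    import warnings

    def resolve_num_workers(num_workers=None):
        # default: half the available cores
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
        # macOS + Python 3.8+ : fall back to single-process data loading
        if num_workers > 0 and sys.platform == "darwin" and sys.version_info >= (3, 8):
            warnings.warn(
                "num_workers > 0 is not supported with macOS and Python 3.8+: "
                "setting num_workers = 0."
            )
            num_workers = 0
        return num_workers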

@@ -93,7 +93,7 @@ class Fileprocessor:
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):

if matching_function is None:
if name.lower() == "vctk":
if name.lower() in ("vctk", "valentini"):
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "dns-2020":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
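
The change lets the dataset name "valentini" pick the one-to-one matching function that was previously keyed only on "vctk". A hedged usage sketch; the import path and directory names are assumptions for illustration only:

    from mayavoz.data.fileprocessor import Fileprocessor  # assumed module path

    processor = Fileprocessor.from_name(
        name="valentini",                # now mapped to ProcessorFunctions.one_to_one
        clean_dir="clean_trainset_wav",  # hypothetical directories
        noisy_dir="noisy_trainset_wav",
    )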

@@ -1,4 +1,4 @@
import logging
import warnings

import numpy as np
import torch

@@ -134,7 +134,7 @@ class Pesq:
try:
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ")
warnings.warn(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values))
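
The only change here is reporting per-sample PESQ failures through warnings.warn instead of logging.warning; the scores that do compute are still averaged. A minimal sketch of that warn-and-skip pattern, with pesq_fn standing in for whatever callable produces the raw score (an assumption):

    import warnings

    import numpy as np
    import torch

    def batched_pesq(pesq_fn, preds, targets):
        values = []
        for pred, target in zip(preds, targets):
            try:
                values.append(pesq_fn(pred.squeeze(), target.squeeze()))
            except Exception as e:
                # a single bad pair should not abort the whole batch
                warnings.warn(f"{e} error occurred while calculating PESQ")
        return torch.tensor(np.mean(values))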

@@ -1,4 +1,4 @@
import logging
import warnings
from typing import Any, List, Optional, Tuple, Union

import torch

@@ -140,11 +140,11 @@ class DCCRN(Mayamodel):
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate
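
The same two fixes appear in DCCRN here and in Demucs and WaveUnet below: keep an explicitly passed duration when no MayaDataset is supplied (previously it was overwritten with None), and warn with warnings.warn before adopting the dataset's sampling rate. A condensed sketch of that shared reconciliation step, using a plain None check in place of the isinstance(dataset, MayaDataset) test:

    import warnings

    def reconcile_with_dataset(dataset, duration, sampling_rate):
        # keep the caller's duration unless a dataset provides one (the bug fix above)
        duration = dataset.duration if dataset is not None else duration
        if dataset is not None and sampling_rate != dataset.sampling_rate:
            warnings.warn(
                f"model sampling rate {sampling_rate} should match "
                f"dataset sampling rate {dataset.sampling_rate}"
            )
            sampling_rate = dataset.sampling_rate
        return duration, sampling_rate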

@@ -1,5 +1,5 @@
import logging
import math
import warnings
from typing import List, Optional, Union

import torch.nn.functional as F

@@ -136,16 +136,17 @@ class Demucs(Mayamodel):
normalize=True,
lr: float = 1e-3,
dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
floor=1e-3,
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

@@ -24,7 +24,7 @@ CACHE_DIR = os.getenv(
)
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu"
SAVE_NAME = "enhancer"
SAVE_NAME = "mayavoz"


class Mayamodel(pl.LightningModule):

@@ -1,4 +1,4 @@
import logging
import warnings
from typing import List, Optional, Union

import torch

@@ -103,11 +103,11 @@ class WaveUnet(Mayamodel):
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

File diff suppressed because one or more lines are too long

@@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset
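
This fixture, repeated across the model test files below, now pins sampling_rate=16000 explicitly when building the VCTK test dataset. A hedged reconstruction of the full fixture; root_dir and the plain dict used for `files` are assumptions (the project may pass a dedicated files object):

    import pytest

    from mayavoz.data.dataset import MayaDataset  # _target_ path from the dataset config above

    @pytest.fixture
    def vctk_dataset():
        root_dir = "tests/data/vctk"  # hypothetical location of the test audio
        files = dict(
            train_clean="clean_trainset_wav",  # hypothetical train split names
            train_noisy="noisy_trainset_wav",
            test_clean="clean_testset_wav",    # split names shown in the hunk above
            test_noisy="noisy_testset_wav",
        )
        return MayaDataset(
            name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
        )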

@@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset

@@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset