Merge pull request #24 from shahules786/dev
Minor improvements/bug fixes
This commit is contained in:
commit
da85de13ad
|
|
@ -1,4 +1,8 @@
|
||||||
#local
|
#local
|
||||||
|
cleaned_my_voice.wav
|
||||||
|
lightning_logs/
|
||||||
|
my_voice.wav
|
||||||
|
pretrained/
|
||||||
*.ckpt
|
*.ckpt
|
||||||
*_local.yaml
|
*_local.yaml
|
||||||
cli/train_config/dataset/Vctk_local.yaml
|
cli/train_config/dataset/Vctk_local.yaml
|
||||||
|
|
|
||||||
|
|
@ -1 +1,2 @@
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
from mayavoz.models import Mayamodel
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_path="train_config", config_name="config")
|
@hydra.main(config_path="train_config", config_name="config")
|
||||||
def main(config: DictConfig):
|
def train(config: DictConfig):
|
||||||
|
|
||||||
OmegaConf.save(config, "config_log.yaml")
|
OmegaConf.save(config, "config.yaml")
|
||||||
|
|
||||||
callbacks = []
|
callbacks = []
|
||||||
logger = MLFlowLogger(
|
logger = MLFlowLogger(
|
||||||
|
|
@ -96,7 +96,7 @@ def main(config: DictConfig):
|
||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
|
|
||||||
logger.experiment.log_artifact(
|
logger.experiment.log_artifact(
|
||||||
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
|
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
|
||||||
)
|
)
|
||||||
|
|
||||||
saved_location = os.path.join(
|
saved_location = os.path.join(
|
||||||
|
|
@ -117,4 +117,4 @@ def main(config: DictConfig):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
train()
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
_target_: mayavoz.data.dataset.MayaDataset
|
_target_: mayavoz.data.dataset.MayaDataset
|
||||||
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
|
||||||
name : dns-2020
|
name : dns-2020
|
||||||
|
root_dir : /Users/shahules/Myprojects/MS-SNSD
|
||||||
duration : 2.0
|
duration : 2.0
|
||||||
sampling_rate: 16000
|
sampling_rate: 16000
|
||||||
batch_size: 32
|
batch_size: 32
|
||||||
valid_size: 0.05
|
min_valid_minutes: 15
|
||||||
files:
|
files:
|
||||||
train_clean : CleanSpeech_training
|
train_clean : CleanSpeech_training
|
||||||
test_clean : CleanSpeech_training
|
test_clean : CleanSpeech_training
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
@ -80,6 +82,21 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
self._validation = []
|
self._validation = []
|
||||||
if num_workers is None:
|
if num_workers is None:
|
||||||
num_workers = multiprocessing.cpu_count() // 2
|
num_workers = multiprocessing.cpu_count() // 2
|
||||||
|
if num_workers is None:
|
||||||
|
num_workers = multiprocessing.cpu_count() // 2
|
||||||
|
|
||||||
|
if (
|
||||||
|
num_workers > 0
|
||||||
|
and sys.platform == "darwin"
|
||||||
|
and sys.version_info[0] >= 3
|
||||||
|
and sys.version_info[1] >= 8
|
||||||
|
):
|
||||||
|
warnings.warn(
|
||||||
|
"num_workers > 0 is not supported with macOS and Python 3.8+: "
|
||||||
|
"setting num_workers = 0."
|
||||||
|
)
|
||||||
|
num_workers = 0
|
||||||
|
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
if min_valid_minutes > 0.0:
|
if min_valid_minutes > 0.0:
|
||||||
self.min_valid_minutes = min_valid_minutes
|
self.min_valid_minutes = min_valid_minutes
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,7 @@ class Fileprocessor:
|
||||||
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
|
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
|
||||||
|
|
||||||
if matching_function is None:
|
if matching_function is None:
|
||||||
if name.lower() == "vctk":
|
if name.lower() in ("vctk", "valentini"):
|
||||||
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
|
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
|
||||||
elif name.lower() == "dns-2020":
|
elif name.lower() == "dns-2020":
|
||||||
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
|
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import logging
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -134,7 +134,7 @@ class Pesq:
|
||||||
try:
|
try:
|
||||||
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
|
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"{e} error occured while calculating PESQ")
|
warnings.warn(f"{e} error occured while calculating PESQ")
|
||||||
return torch.tensor(np.mean(pesq_values))
|
return torch.tensor(np.mean(pesq_values))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import logging
|
import warnings
|
||||||
from typing import Any, List, Optional, Tuple, Union
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -140,11 +140,11 @@ class DCCRN(Mayamodel):
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warning(
|
warnings.warn(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -136,16 +136,17 @@ class Demucs(Mayamodel):
|
||||||
normalize=True,
|
normalize=True,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[MayaDataset] = None,
|
dataset: Optional[MayaDataset] = None,
|
||||||
|
duration: Optional[float] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
floor=1e-3,
|
floor=1e-3,
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warning(
|
warnings.warn(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ CACHE_DIR = os.getenv(
|
||||||
)
|
)
|
||||||
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
|
||||||
DEFAULT_DEVICE = "cpu"
|
DEFAULT_DEVICE = "cpu"
|
||||||
SAVE_NAME = "enhancer"
|
SAVE_NAME = "mayavoz"
|
||||||
|
|
||||||
|
|
||||||
class Mayamodel(pl.LightningModule):
|
class Mayamodel(pl.LightningModule):
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import logging
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -103,11 +103,11 @@ class WaveUnet(Mayamodel):
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, MayaDataset) else None
|
dataset.duration if isinstance(dataset, MayaDataset) else duration
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warning(
|
warnings.warn(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -15,7 +15,9 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(
|
||||||
|
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
|
||||||
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(
|
||||||
|
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
|
||||||
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ def vctk_dataset():
|
||||||
test_clean="clean_testset_wav",
|
test_clean="clean_testset_wav",
|
||||||
test_noisy="noisy_testset_wav",
|
test_noisy="noisy_testset_wav",
|
||||||
)
|
)
|
||||||
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
|
dataset = MayaDataset(
|
||||||
|
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
|
||||||
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue