Merge pull request #24 from shahules786/dev

Minor improvements/bug fixes
This commit is contained in:
Shahul ES 2022-11-15 22:03:45 +05:30 committed by GitHub
commit da85de13ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 311 additions and 86 deletions

4
.gitignore vendored
View File

@ -1,4 +1,8 @@
#local
cleaned_my_voice.wav
lightning_logs/
my_voice.wav
pretrained/
*.ckpt
*_local.yaml
cli/train_config/dataset/Vctk_local.yaml

View File

@ -1 +1,2 @@
__import__("pkg_resources").declare_namespace(__name__)
from mayavoz.models import Mayamodel

View File

@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):
def train(config: DictConfig):
OmegaConf.save(config, "config_log.yaml")
OmegaConf.save(config, "config.yaml")
callbacks = []
logger = MLFlowLogger(
@ -96,7 +96,7 @@ def main(config: DictConfig):
trainer.test(model)
logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml"
logger.run_id, f"{trainer.default_root_dir}/config.yaml"
)
saved_location = os.path.join(
@ -117,4 +117,4 @@ def main(config: DictConfig):
if __name__ == "__main__":
main()
train()

View File

@ -1,10 +1,10 @@
_target_: mayavoz.data.dataset.MayaDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 2.0
sampling_rate: 16000
batch_size: 32
valid_size: 0.05
min_valid_minutes: 15
files:
train_clean : CleanSpeech_training
test_clean : CleanSpeech_training

View File

@ -1,6 +1,8 @@
import math
import multiprocessing
import os
import sys
import warnings
from pathlib import Path
from typing import Optional
@ -80,6 +82,21 @@ class TaskDataset(pl.LightningDataModule):
self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
if (
num_workers > 0
and sys.platform == "darwin"
and sys.version_info[0] >= 3
and sys.version_info[1] >= 8
):
warnings.warn(
"num_workers > 0 is not supported with macOS and Python 3.8+: "
"setting num_workers = 0."
)
num_workers = 0
self.num_workers = num_workers
if min_valid_minutes > 0.0:
self.min_valid_minutes = min_valid_minutes

View File

@ -93,7 +93,7 @@ class Fileprocessor:
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
if matching_function is None:
if name.lower() == "vctk":
if name.lower() in ("vctk", "valentini"):
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "dns-2020":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)

View File

@ -1,4 +1,4 @@
import logging
import warnings
import numpy as np
import torch
@ -134,7 +134,7 @@ class Pesq:
try:
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
except Exception as e:
logging.warning(f"{e} error occured while calculating PESQ")
warnings.warn(f"{e} error occured while calculating PESQ")
return torch.tensor(np.mean(pesq_values))

View File

@ -1,4 +1,4 @@
import logging
import warnings
from typing import Any, List, Optional, Tuple, Union
import torch
@ -140,11 +140,11 @@ class DCCRN(Mayamodel):
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

View File

@ -1,5 +1,5 @@
import logging
import math
import warnings
from typing import List, Optional, Union
import torch.nn.functional as F
@ -136,16 +136,17 @@ class Demucs(Mayamodel):
normalize=True,
lr: float = 1e-3,
dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse",
floor=1e-3,
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

View File

@ -24,7 +24,7 @@ CACHE_DIR = os.getenv(
)
HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu"
SAVE_NAME = "enhancer"
SAVE_NAME = "mayavoz"
class Mayamodel(pl.LightningModule):

View File

@ -1,4 +1,4 @@
import logging
import warnings
from typing import List, Optional, Union
import torch
@ -103,11 +103,11 @@ class WaveUnet(Mayamodel):
metric: Union[str, List] = "mse",
):
duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None
dataset.duration if isinstance(dataset, MayaDataset) else duration
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warning(
warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

File diff suppressed because one or more lines are too long

View File

@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset

View File

@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset

View File

@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav",
)
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files)
dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset