Merge pull request #24 from shahules786/dev

Minor improvements/bug fixes
This commit is contained in:
Shahul ES 2022-11-15 22:03:45 +05:30 committed by GitHub
commit da85de13ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 311 additions and 86 deletions

4
.gitignore vendored
View File

@ -1,4 +1,8 @@
#local #local
cleaned_my_voice.wav
lightning_logs/
my_voice.wav
pretrained/
*.ckpt *.ckpt
*_local.yaml *_local.yaml
cli/train_config/dataset/Vctk_local.yaml cli/train_config/dataset/Vctk_local.yaml

View File

@ -1 +1,2 @@
__import__("pkg_resources").declare_namespace(__name__) __import__("pkg_resources").declare_namespace(__name__)
from mayavoz.models import Mayamodel

View File

@ -19,9 +19,9 @@ JOB_ID = os.environ.get("SLURM_JOBID", "0")
@hydra.main(config_path="train_config", config_name="config") @hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig): def train(config: DictConfig):
OmegaConf.save(config, "config_log.yaml") OmegaConf.save(config, "config.yaml")
callbacks = [] callbacks = []
logger = MLFlowLogger( logger = MLFlowLogger(
@ -96,7 +96,7 @@ def main(config: DictConfig):
trainer.test(model) trainer.test(model)
logger.experiment.log_artifact( logger.experiment.log_artifact(
logger.run_id, f"{trainer.default_root_dir}/config_log.yaml" logger.run_id, f"{trainer.default_root_dir}/config.yaml"
) )
saved_location = os.path.join( saved_location = os.path.join(
@ -117,4 +117,4 @@ def main(config: DictConfig):
if __name__ == "__main__": if __name__ == "__main__":
main() train()

View File

@ -1,10 +1,10 @@
_target_: mayavoz.data.dataset.MayaDataset _target_: mayavoz.data.dataset.MayaDataset
root_dir : /Users/shahules/Myprojects/MS-SNSD
name : dns-2020 name : dns-2020
root_dir : /Users/shahules/Myprojects/MS-SNSD
duration : 2.0 duration : 2.0
sampling_rate: 16000 sampling_rate: 16000
batch_size: 32 batch_size: 32
valid_size: 0.05 min_valid_minutes: 15
files: files:
train_clean : CleanSpeech_training train_clean : CleanSpeech_training
test_clean : CleanSpeech_training test_clean : CleanSpeech_training

View File

@ -1,6 +1,8 @@
import math import math
import multiprocessing import multiprocessing
import os import os
import sys
import warnings
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -80,6 +82,21 @@ class TaskDataset(pl.LightningDataModule):
self._validation = [] self._validation = []
if num_workers is None: if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2 num_workers = multiprocessing.cpu_count() // 2
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
if (
num_workers > 0
and sys.platform == "darwin"
and sys.version_info[0] >= 3
and sys.version_info[1] >= 8
):
warnings.warn(
"num_workers > 0 is not supported with macOS and Python 3.8+: "
"setting num_workers = 0."
)
num_workers = 0
self.num_workers = num_workers self.num_workers = num_workers
if min_valid_minutes > 0.0: if min_valid_minutes > 0.0:
self.min_valid_minutes = min_valid_minutes self.min_valid_minutes = min_valid_minutes

View File

@ -93,7 +93,7 @@ class Fileprocessor:
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None): def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
if matching_function is None: if matching_function is None:
if name.lower() == "vctk": if name.lower() in ("vctk", "valentini"):
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "dns-2020": elif name.lower() == "dns-2020":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)

View File

@ -1,4 +1,4 @@
import logging import warnings
import numpy as np import numpy as np
import torch import torch
@ -134,7 +134,7 @@ class Pesq:
try: try:
pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze())) pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
except Exception as e: except Exception as e:
logging.warning(f"{e} error occurred while calculating PESQ") warnings.warn(f"{e} error occurred while calculating PESQ")
return torch.tensor(np.mean(pesq_values)) return torch.tensor(np.mean(pesq_values))

View File

@ -1,4 +1,4 @@
import logging import warnings
from typing import Any, List, Optional, Tuple, Union from typing import Any, List, Optional, Tuple, Union
import torch import torch
@ -140,11 +140,11 @@ class DCCRN(Mayamodel):
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None dataset.duration if isinstance(dataset, MayaDataset) else duration
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:
logging.warning( warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
) )
sampling_rate = dataset.sampling_rate sampling_rate = dataset.sampling_rate

View File

@ -1,5 +1,5 @@
import logging
import math import math
import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
import torch.nn.functional as F import torch.nn.functional as F
@ -136,16 +136,17 @@ class Demucs(Mayamodel):
normalize=True, normalize=True,
lr: float = 1e-3, lr: float = 1e-3,
dataset: Optional[MayaDataset] = None, dataset: Optional[MayaDataset] = None,
duration: Optional[float] = None,
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
floor=1e-3, floor=1e-3,
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None dataset.duration if isinstance(dataset, MayaDataset) else duration
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:
logging.warning( warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
) )
sampling_rate = dataset.sampling_rate sampling_rate = dataset.sampling_rate

View File

@ -24,7 +24,7 @@ CACHE_DIR = os.getenv(
) )
HF_TORCH_WEIGHTS = "pytorch_model.ckpt" HF_TORCH_WEIGHTS = "pytorch_model.ckpt"
DEFAULT_DEVICE = "cpu" DEFAULT_DEVICE = "cpu"
SAVE_NAME = "enhancer" SAVE_NAME = "mayavoz"
class Mayamodel(pl.LightningModule): class Mayamodel(pl.LightningModule):

View File

@ -1,4 +1,4 @@
import logging import warnings
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch
@ -103,11 +103,11 @@ class WaveUnet(Mayamodel):
metric: Union[str, List] = "mse", metric: Union[str, List] = "mse",
): ):
duration = ( duration = (
dataset.duration if isinstance(dataset, MayaDataset) else None dataset.duration if isinstance(dataset, MayaDataset) else duration
) )
if dataset is not None: if dataset is not None:
if sampling_rate != dataset.sampling_rate: if sampling_rate != dataset.sampling_rate:
logging.warning( warnings.warn(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
) )
sampling_rate = dataset.sampling_rate sampling_rate = dataset.sampling_rate

File diff suppressed because one or more lines are too long

View File

@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav", test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav", test_noisy="noisy_testset_wav",
) )
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset return dataset

View File

@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav", test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav", test_noisy="noisy_testset_wav",
) )
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset return dataset

View File

@ -15,7 +15,9 @@ def vctk_dataset():
test_clean="clean_testset_wav", test_clean="clean_testset_wav",
test_noisy="noisy_testset_wav", test_noisy="noisy_testset_wav",
) )
dataset = MayaDataset(name="vctk", root_dir=root_dir, files=files) dataset = MayaDataset(
name="vctk", root_dir=root_dir, files=files, sampling_rate=16000
)
return dataset return dataset