Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-07 10:43:34 +05:30
commit e90efe3163
10 changed files with 191 additions and 13 deletions

View File

@ -37,13 +37,15 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
sudo apt-get install libsndfile1
pip install -r requirements.txt pip install -r requirements.txt
pip install black flake8 pytest-cov pip install black pytest-cov
- name: Install enhancer
run: |
pip install -e .[dev,testing]
- name: Run black - name: Run black
run: run:
black --check . black --check . --exclude enhancer/version.py
- name: Run flake8
run: flake8
- name: Test with pytest - name: Test with pytest
run: run:
pytest tests --cov=enhancer/ pytest tests --cov=enhancer/

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
#local #local
*.ckpt
cli/train_config/dataset/Vctk_local.yaml cli/train_config/dataset/Vctk_local.yaml
.DS_Store .DS_Store
outputs/ outputs/

View File

@ -1 +1 @@
__version__ = "0.0.1" __import__("pkg_resources").declare_namespace(__name__)

View File

@ -139,7 +139,9 @@ class Inference:
if filename.is_file(): if filename.is_file():
raise FileExistsError(f"file {filename} already exists") raise FileExistsError(f"file {filename} already exists")
else: else:
wavfile.write(filename, rate=sr, data=waveform.detach().cpu()) wavfile.write(
filename, rate=sr, data=waveform.detach().cpu().numpy()
)
@staticmethod @staticmethod
def prepare_output( def prepare_output(

View File

@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch.optim import Adam from torch.optim import Adam
from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference from enhancer.inference import Inference
from enhancer.loss import Avergeloss from enhancer.loss import Avergeloss
from enhancer.version import __version__
CACHE_DIR = "" CACHE_DIR = ""
HF_TORCH_WEIGHTS = "" HF_TORCH_WEIGHTS = ""
@ -120,7 +120,11 @@ class Model(pl.LightningModule):
loss = self.loss(prediction, target) loss = self.loss(prediction, target)
if self.logger: if (
(self.logger)
and (self.global_step > 50)
and (self.global_step % 50 == 0)
):
self.logger.experiment.log_metric( self.logger.experiment.log_metric(
run_id=self.logger.run_id, run_id=self.logger.run_id,
key="train_loss", key="train_loss",
@ -141,7 +145,11 @@ class Model(pl.LightningModule):
self.log("val_metric", metric_val.item()) self.log("val_metric", metric_val.item())
self.log("val_loss", loss_val.item()) self.log("val_loss", loss_val.item())
if self.logger: if (
(self.logger)
and (self.global_step > 50)
and (self.global_step % 50 == 0)
):
self.logger.experiment.log_metric( self.logger.experiment.log_metric(
run_id=self.logger.run_id, run_id=self.logger.run_id,
key="val_loss", key="val_loss",
@ -209,8 +217,7 @@ class Model(pl.LightningModule):
to True or to a string containing your hugginface.co authentication to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login` token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional cache_dir: Path or str, optional
Path to model cache directory. Defaults to content of PYANNOTE_CACHE Path to model cache directory
environment variable, or "~/.cache/torch/pyannote" when unset.
kwargs: optional kwargs: optional
Any extra keyword args needed to init the model. Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values. Can also be used to override saved hyperparameter values.
@ -290,10 +297,9 @@ class Model(pl.LightningModule):
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}" ), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
batch_predictions = [] batch_predictions = []
self.eval().to(self.device) self.eval().to(self.device)
with torch.no_grad(): with torch.no_grad():
for batch_id in range(0, batch.shape[0], batch_size): for batch_id in range(0, batch.shape[0], batch_size):
batch_data = batch[batch_id : batch_id + batch_size, :, :].to( batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
self.device self.device
) )
prediction = self(batch_data) prediction = self(batch_data)

1
enhancer/version.py Normal file
View File

@ -0,0 +1 @@
__version__ = "0.0.1"

View File

@ -6,9 +6,11 @@ librosa>=0.9.2
mlflow>=1.29.0 mlflow>=1.29.0
numpy>=1.23.3 numpy>=1.23.3
protobuf>=3.19.6 protobuf>=3.19.6
pytest-lazy-fixture>=0.6.3
pytorch-lightning>=1.7.7 pytorch-lightning>=1.7.7
scikit-learn>=1.1.2 scikit-learn>=1.1.2
scipy>=1.9.1 scipy>=1.9.1
soundfile>=0.11.0
torch>=1.12.1 torch>=1.12.1
torchaudio>=0.12.1 torchaudio>=0.12.1
tqdm>=4.64.1 tqdm>=4.64.1

100
setup.cfg Normal file
View File

@ -0,0 +1,100 @@
# This file is used to configure your project.
# Read more about the various options under:
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata]
name = enhancer
description = Deep learning for speech enhacement
author = Shahul Ess
author-email = shahules786@gmail.com
license = mit
long-description = file: README.md
long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
# Change if running only on Windows, Mac or Linux (comma-separated)
platforms = Linux, Mac
# Add here all kinds of additional classifiers as defined under
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers =
Development Status :: 4 - Beta
Programming Language :: Python
[options]
zip_safe = False
packages = find:
include_package_data = True
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
setup_requires = setuptools
# Add here dependencies of your project (semicolon/line-separated), e.g.
# install_requires = numpy; scipy
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
python_requires = >=3.8
[options.packages.find]
where = .
exclude =
tests
[options.extras_require]
# Add here additional requirements for extra features, to install with:
# `pip install fastaudio[PDF]` like:
# PDF = ReportLab; RXP
# Add here test requirements (semicolon/line-separated)
testing =
pytest>=7.1.3
pytest-cov>=4.0.0
dev =
pre-commit>=2.20.0
black>=22.8.0
flake8>=5.0.4
cli =
hydra-core >=1.1,<=1.2
[options.entry_points]
console_scripts =
enhancer-train=enhancer.cli.train:train
[test]
# py.test options when running `python setup.py test`
# addopts = --verbose
extras = True
[tool:pytest]
# Options for py.test:
# Specify command line options as you would do when invoking py.test directly.
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
# in order to write a coverage file that can be read by Jenkins.
addopts =
--cov enhancer --cov-report term-missing
--verbose
norecursedirs =
dist
build
.tox
testpaths = tests
[aliases]
dists = bdist_wheel
[bdist_wheel]
# Use this option if your package is pure-python
universal = 1
[build_sphinx]
source_dir = doc
build_dir = build/sphinx
[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
# VCS export must be deactivated since we are using setuptools-scm
no-vcs = 1
formats = bdist_wheel
[flake8]
# Some sane defaults for the code style checker flake8
exclude =
.tox
build
dist
.eggs

63
setup.py Normal file
View File

@ -0,0 +1,63 @@
import os
import sys
from pathlib import Path
from pkg_resources import VersionConflict, require
from setuptools import find_packages, setup
with open("README.md") as f:
long_description = f.read()
with open("requirements.txt") as f:
requirements = f.read().splitlines()
try:
require("setuptools>=38.3")
except VersionConflict:
print("Error: version of setuptools is too old (<38.3)!")
sys.exit(1)
ROOT_DIR = Path(__file__).parent.resolve()
# Creating the version file
with open("version.txt") as f:
version = f.read()
version = version.strip()
sha = "Unknown"
if os.getenv("BUILD_VERSION"):
version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
version += "+" + sha[:7]
print("-- Building version " + version)
version_path = ROOT_DIR / "enhancer" / "version.py"
with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version))
if __name__ == "__main__":
setup(
name="enhancer",
namespace_packages=["enhancer"],
version=version,
packages=find_packages(),
install_requires=requirements,
description="Deep learning toolkit for speech enhancement",
long_description=long_description,
long_description_content_type="text/markdown",
author="Shahul Es",
author_email="shahules786@gmail.com",
url="",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
],
)

1
version.txt Normal file
View File

@ -0,0 +1 @@
0.0.1