Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-07 10:43:34 +05:30
commit e90efe3163
10 changed files with 191 additions and 13 deletions

View File

@ -37,13 +37,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
sudo apt-get install libsndfile1
pip install -r requirements.txt
pip install black flake8 pytest-cov
pip install black pytest-cov
- name: Install enhancer
run: |
pip install -e .[dev,testing]
- name: Run black
run:
black --check .
- name: Run flake8
run: flake8
black --check . --exclude enhancer/version.py
- name: Test with pytest
run:
pytest tests --cov=enhancer/

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
#local
*.ckpt
cli/train_config/dataset/Vctk_local.yaml
.DS_Store
outputs/

View File

@ -1 +1 @@
__version__ = "0.0.1"
__import__("pkg_resources").declare_namespace(__name__)

View File

@ -139,7 +139,9 @@ class Inference:
if filename.is_file():
raise FileExistsError(f"file {filename} already exists")
else:
wavfile.write(filename, rate=sr, data=waveform.detach().cpu())
wavfile.write(
filename, rate=sr, data=waveform.detach().cpu().numpy()
)
@staticmethod
def prepare_output(

View File

@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url
from pytorch_lightning.utilities.cloud_io import load as pl_load
from torch.optim import Adam
from enhancer import __version__
from enhancer.data.dataset import EnhancerDataset
from enhancer.inference import Inference
from enhancer.loss import Avergeloss
from enhancer.version import __version__
CACHE_DIR = ""
HF_TORCH_WEIGHTS = ""
@ -120,7 +120,11 @@ class Model(pl.LightningModule):
loss = self.loss(prediction, target)
if self.logger:
if (
(self.logger)
and (self.global_step > 50)
and (self.global_step % 50 == 0)
):
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="train_loss",
@ -141,7 +145,11 @@ class Model(pl.LightningModule):
self.log("val_metric", metric_val.item())
self.log("val_loss", loss_val.item())
if self.logger:
if (
(self.logger)
and (self.global_step > 50)
and (self.global_step % 50 == 0)
):
self.logger.experiment.log_metric(
run_id=self.logger.run_id,
key="val_loss",
@ -209,8 +217,7 @@ class Model(pl.LightningModule):
to True or to a string containing your hugginface.co authentication
token that can be obtained by running `huggingface-cli login`
cache_dir: Path or str, optional
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
Path to model cache directory
kwargs: optional
Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values.
@ -290,10 +297,9 @@ class Model(pl.LightningModule):
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
batch_predictions = []
self.eval().to(self.device)
with torch.no_grad():
for batch_id in range(0, batch.shape[0], batch_size):
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
self.device
)
prediction = self(batch_data)

1
enhancer/version.py Normal file
View File

@ -0,0 +1 @@
__version__ = "0.0.1"

View File

@ -6,9 +6,11 @@ librosa>=0.9.2
mlflow>=1.29.0
numpy>=1.23.3
protobuf>=3.19.6
pytest-lazy-fixture>=0.6.3
pytorch-lightning>=1.7.7
scikit-learn>=1.1.2
scipy>=1.9.1
soundfile>=0.11.0
torch>=1.12.1
torchaudio>=0.12.1
tqdm>=4.64.1

100
setup.cfg Normal file
View File

@ -0,0 +1,100 @@
# This file is used to configure your project.
# Read more about the various options under:
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata]
name = enhancer
description = Deep learning for speech enhacement
author = Shahul Ess
author-email = shahules786@gmail.com
license = mit
long-description = file: README.md
long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
# Change if running only on Windows, Mac or Linux (comma-separated)
platforms = Linux, Mac
# Add here all kinds of additional classifiers as defined under
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
classifiers =
Development Status :: 4 - Beta
Programming Language :: Python
[options]
zip_safe = False
packages = find:
include_package_data = True
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
setup_requires = setuptools
# Add here dependencies of your project (semicolon/line-separated), e.g.
# install_requires = numpy; scipy
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
python_requires = >=3.8
[options.packages.find]
where = .
exclude =
tests
[options.extras_require]
# Add here additional requirements for extra features, to install with:
# `pip install fastaudio[PDF]` like:
# PDF = ReportLab; RXP
# Add here test requirements (semicolon/line-separated)
testing =
pytest>=7.1.3
pytest-cov>=4.0.0
dev =
pre-commit>=2.20.0
black>=22.8.0
flake8>=5.0.4
cli =
hydra-core >=1.1,<=1.2
[options.entry_points]
console_scripts =
enhancer-train=enhancer.cli.train:train
[test]
# py.test options when running `python setup.py test`
# addopts = --verbose
extras = True
[tool:pytest]
# Options for py.test:
# Specify command line options as you would do when invoking py.test directly.
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
# in order to write a coverage file that can be read by Jenkins.
addopts =
--cov enhancer --cov-report term-missing
--verbose
norecursedirs =
dist
build
.tox
testpaths = tests
[aliases]
dists = bdist_wheel
[bdist_wheel]
# Use this option if your package is pure-python
universal = 1
[build_sphinx]
source_dir = doc
build_dir = build/sphinx
[devpi:upload]
# Options for the devpi: PyPI server and packaging tool
# VCS export must be deactivated since we are using setuptools-scm
no-vcs = 1
formats = bdist_wheel
[flake8]
# Some sane defaults for the code style checker flake8
exclude =
.tox
build
dist
.eggs

63
setup.py Normal file
View File

@ -0,0 +1,63 @@
import os
import sys
from pathlib import Path
from pkg_resources import VersionConflict, require
from setuptools import find_packages, setup
with open("README.md") as f:
long_description = f.read()
with open("requirements.txt") as f:
requirements = f.read().splitlines()
try:
require("setuptools>=38.3")
except VersionConflict:
print("Error: version of setuptools is too old (<38.3)!")
sys.exit(1)
ROOT_DIR = Path(__file__).parent.resolve()
# Creating the version file
with open("version.txt") as f:
version = f.read()
version = version.strip()
sha = "Unknown"
if os.getenv("BUILD_VERSION"):
version = os.getenv("BUILD_VERSION")
elif sha != "Unknown":
version += "+" + sha[:7]
print("-- Building version " + version)
version_path = ROOT_DIR / "enhancer" / "version.py"
with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version))
if __name__ == "__main__":
setup(
name="enhancer",
namespace_packages=["enhancer"],
version=version,
packages=find_packages(),
install_requires=requirements,
description="Deep learning toolkit for speech enhancement",
long_description=long_description,
long_description_content_type="text/markdown",
author="Shahul Es",
author_email="shahules786@gmail.com",
url="",
classifiers=[
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Topic :: Scientific/Engineering",
],
)

1
version.txt Normal file
View File

@ -0,0 +1 @@
0.0.1