Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in:
commit
e90efe3163
|
|
@ -37,13 +37,15 @@ jobs:
|
|||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
sudo apt-get install libsndfile1
|
||||
pip install -r requirements.txt
|
||||
pip install black flake8 pytest-cov
|
||||
pip install black pytest-cov
|
||||
- name: Install enhancer
|
||||
run: |
|
||||
pip install -e .[dev,testing]
|
||||
- name: Run black
|
||||
run:
|
||||
black --check .
|
||||
- name: Run flake8
|
||||
run: flake8
|
||||
black --check . --exclude enhancer/version.py
|
||||
- name: Test with pytest
|
||||
run:
|
||||
pytest tests --cov=enhancer/
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#local
|
||||
*.ckpt
|
||||
cli/train_config/dataset/Vctk_local.yaml
|
||||
.DS_Store
|
||||
outputs/
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
__version__ = "0.0.1"
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
|
|
|
|||
|
|
@ -139,7 +139,9 @@ class Inference:
|
|||
if filename.is_file():
|
||||
raise FileExistsError(f"file {filename} already exists")
|
||||
else:
|
||||
wavfile.write(filename, rate=sr, data=waveform.detach().cpu())
|
||||
wavfile.write(
|
||||
filename, rate=sr, data=waveform.detach().cpu().numpy()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def prepare_output(
|
||||
|
|
|
|||
|
|
@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url
|
|||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from torch.optim import Adam
|
||||
|
||||
from enhancer import __version__
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
from enhancer.inference import Inference
|
||||
from enhancer.loss import Avergeloss
|
||||
from enhancer.version import __version__
|
||||
|
||||
CACHE_DIR = ""
|
||||
HF_TORCH_WEIGHTS = ""
|
||||
|
|
@ -120,7 +120,11 @@ class Model(pl.LightningModule):
|
|||
|
||||
loss = self.loss(prediction, target)
|
||||
|
||||
if self.logger:
|
||||
if (
|
||||
(self.logger)
|
||||
and (self.global_step > 50)
|
||||
and (self.global_step % 50 == 0)
|
||||
):
|
||||
self.logger.experiment.log_metric(
|
||||
run_id=self.logger.run_id,
|
||||
key="train_loss",
|
||||
|
|
@ -141,7 +145,11 @@ class Model(pl.LightningModule):
|
|||
self.log("val_metric", metric_val.item())
|
||||
self.log("val_loss", loss_val.item())
|
||||
|
||||
if self.logger:
|
||||
if (
|
||||
(self.logger)
|
||||
and (self.global_step > 50)
|
||||
and (self.global_step % 50 == 0)
|
||||
):
|
||||
self.logger.experiment.log_metric(
|
||||
run_id=self.logger.run_id,
|
||||
key="val_loss",
|
||||
|
|
@ -209,8 +217,7 @@ class Model(pl.LightningModule):
|
|||
to True or to a string containing your hugginface.co authentication
|
||||
token that can be obtained by running `huggingface-cli login`
|
||||
cache_dir: Path or str, optional
|
||||
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
|
||||
environment variable, or "~/.cache/torch/pyannote" when unset.
|
||||
Path to model cache directory
|
||||
kwargs: optional
|
||||
Any extra keyword args needed to init the model.
|
||||
Can also be used to override saved hyperparameter values.
|
||||
|
|
@ -290,10 +297,9 @@ class Model(pl.LightningModule):
|
|||
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||
batch_predictions = []
|
||||
self.eval().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_id in range(0, batch.shape[0], batch_size):
|
||||
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
|
||||
batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
|
||||
self.device
|
||||
)
|
||||
prediction = self(batch_data)
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
__version__ = "0.0.1"
|
||||
|
|
@ -6,9 +6,11 @@ librosa>=0.9.2
|
|||
mlflow>=1.29.0
|
||||
numpy>=1.23.3
|
||||
protobuf>=3.19.6
|
||||
pytest-lazy-fixture>=0.6.3
|
||||
pytorch-lightning>=1.7.7
|
||||
scikit-learn>=1.1.2
|
||||
scipy>=1.9.1
|
||||
soundfile>=0.11.0
|
||||
torch>=1.12.1
|
||||
torchaudio>=0.12.1
|
||||
tqdm>=4.64.1
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
# This file is used to configure your project.
|
||||
# Read more about the various options under:
|
||||
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
|
||||
|
||||
[metadata]
|
||||
name = enhancer
|
||||
description = Deep learning for speech enhacement
|
||||
author = Shahul Ess
|
||||
author-email = shahules786@gmail.com
|
||||
license = mit
|
||||
long-description = file: README.md
|
||||
long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
|
||||
# Change if running only on Windows, Mac or Linux (comma-separated)
|
||||
platforms = Linux, Mac
|
||||
# Add here all kinds of additional classifiers as defined under
|
||||
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
|
||||
classifiers =
|
||||
Development Status :: 4 - Beta
|
||||
Programming Language :: Python
|
||||
|
||||
[options]
|
||||
zip_safe = False
|
||||
packages = find:
|
||||
include_package_data = True
|
||||
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
|
||||
setup_requires = setuptools
|
||||
# Add here dependencies of your project (semicolon/line-separated), e.g.
|
||||
# install_requires = numpy; scipy
|
||||
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
|
||||
python_requires = >=3.8
|
||||
|
||||
[options.packages.find]
|
||||
where = .
|
||||
exclude =
|
||||
tests
|
||||
|
||||
[options.extras_require]
|
||||
# Add here additional requirements for extra features, to install with:
|
||||
# `pip install fastaudio[PDF]` like:
|
||||
# PDF = ReportLab; RXP
|
||||
# Add here test requirements (semicolon/line-separated)
|
||||
testing =
|
||||
pytest>=7.1.3
|
||||
pytest-cov>=4.0.0
|
||||
dev =
|
||||
pre-commit>=2.20.0
|
||||
black>=22.8.0
|
||||
flake8>=5.0.4
|
||||
cli =
|
||||
hydra-core >=1.1,<=1.2
|
||||
|
||||
|
||||
[options.entry_points]
|
||||
|
||||
console_scripts =
|
||||
enhancer-train=enhancer.cli.train:train
|
||||
|
||||
[test]
|
||||
# py.test options when running `python setup.py test`
|
||||
# addopts = --verbose
|
||||
extras = True
|
||||
|
||||
[tool:pytest]
|
||||
# Options for py.test:
|
||||
# Specify command line options as you would do when invoking py.test directly.
|
||||
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
|
||||
# in order to write a coverage file that can be read by Jenkins.
|
||||
addopts =
|
||||
--cov enhancer --cov-report term-missing
|
||||
--verbose
|
||||
norecursedirs =
|
||||
dist
|
||||
build
|
||||
.tox
|
||||
testpaths = tests
|
||||
|
||||
[aliases]
|
||||
dists = bdist_wheel
|
||||
|
||||
[bdist_wheel]
|
||||
# Use this option if your package is pure-python
|
||||
universal = 1
|
||||
|
||||
[build_sphinx]
|
||||
source_dir = doc
|
||||
build_dir = build/sphinx
|
||||
|
||||
[devpi:upload]
|
||||
# Options for the devpi: PyPI server and packaging tool
|
||||
# VCS export must be deactivated since we are using setuptools-scm
|
||||
no-vcs = 1
|
||||
formats = bdist_wheel
|
||||
|
||||
[flake8]
|
||||
# Some sane defaults for the code style checker flake8
|
||||
exclude =
|
||||
.tox
|
||||
build
|
||||
dist
|
||||
.eggs
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from pkg_resources import VersionConflict, require
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
with open("README.md") as f:
|
||||
long_description = f.read()
|
||||
|
||||
with open("requirements.txt") as f:
|
||||
requirements = f.read().splitlines()
|
||||
|
||||
try:
|
||||
require("setuptools>=38.3")
|
||||
except VersionConflict:
|
||||
print("Error: version of setuptools is too old (<38.3)!")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.resolve()
|
||||
# Creating the version file
|
||||
|
||||
with open("version.txt") as f:
|
||||
version = f.read()
|
||||
|
||||
version = version.strip()
|
||||
sha = "Unknown"
|
||||
|
||||
if os.getenv("BUILD_VERSION"):
|
||||
version = os.getenv("BUILD_VERSION")
|
||||
elif sha != "Unknown":
|
||||
version += "+" + sha[:7]
|
||||
print("-- Building version " + version)
|
||||
|
||||
version_path = ROOT_DIR / "enhancer" / "version.py"
|
||||
|
||||
with open(version_path, "w") as f:
|
||||
f.write("__version__ = '{}'\n".format(version))
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup(
|
||||
name="enhancer",
|
||||
namespace_packages=["enhancer"],
|
||||
version=version,
|
||||
packages=find_packages(),
|
||||
install_requires=requirements,
|
||||
description="Deep learning toolkit for speech enhancement",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
author="Shahul Es",
|
||||
author_email="shahules786@gmail.com",
|
||||
url="",
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Topic :: Scientific/Engineering",
|
||||
],
|
||||
)
|
||||
|
|
@ -0,0 +1 @@
|
|||
0.0.1
|
||||
Loading…
Reference in New Issue