Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in:
commit
e90efe3163
|
|
@ -37,13 +37,15 @@ jobs:
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
sudo apt-get install libsndfile1
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
pip install black flake8 pytest-cov
|
pip install black pytest-cov
|
||||||
|
- name: Install enhancer
|
||||||
|
run: |
|
||||||
|
pip install -e .[dev,testing]
|
||||||
- name: Run black
|
- name: Run black
|
||||||
run:
|
run:
|
||||||
black --check .
|
black --check . --exclude enhancer/version.py
|
||||||
- name: Run flake8
|
|
||||||
run: flake8
|
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run:
|
run:
|
||||||
pytest tests --cov=enhancer/
|
pytest tests --cov=enhancer/
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
#local
|
#local
|
||||||
|
*.ckpt
|
||||||
cli/train_config/dataset/Vctk_local.yaml
|
cli/train_config/dataset/Vctk_local.yaml
|
||||||
.DS_Store
|
.DS_Store
|
||||||
outputs/
|
outputs/
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
__version__ = "0.0.1"
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
|
|
||||||
|
|
@ -139,7 +139,9 @@ class Inference:
|
||||||
if filename.is_file():
|
if filename.is_file():
|
||||||
raise FileExistsError(f"file {filename} already exists")
|
raise FileExistsError(f"file {filename} already exists")
|
||||||
else:
|
else:
|
||||||
wavfile.write(filename, rate=sr, data=waveform.detach().cpu())
|
wavfile.write(
|
||||||
|
filename, rate=sr, data=waveform.detach().cpu().numpy()
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_output(
|
def prepare_output(
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,10 @@ from huggingface_hub import cached_download, hf_hub_url
|
||||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
||||||
from enhancer import __version__
|
|
||||||
from enhancer.data.dataset import EnhancerDataset
|
from enhancer.data.dataset import EnhancerDataset
|
||||||
from enhancer.inference import Inference
|
from enhancer.inference import Inference
|
||||||
from enhancer.loss import Avergeloss
|
from enhancer.loss import Avergeloss
|
||||||
|
from enhancer.version import __version__
|
||||||
|
|
||||||
CACHE_DIR = ""
|
CACHE_DIR = ""
|
||||||
HF_TORCH_WEIGHTS = ""
|
HF_TORCH_WEIGHTS = ""
|
||||||
|
|
@ -120,7 +120,11 @@ class Model(pl.LightningModule):
|
||||||
|
|
||||||
loss = self.loss(prediction, target)
|
loss = self.loss(prediction, target)
|
||||||
|
|
||||||
if self.logger:
|
if (
|
||||||
|
(self.logger)
|
||||||
|
and (self.global_step > 50)
|
||||||
|
and (self.global_step % 50 == 0)
|
||||||
|
):
|
||||||
self.logger.experiment.log_metric(
|
self.logger.experiment.log_metric(
|
||||||
run_id=self.logger.run_id,
|
run_id=self.logger.run_id,
|
||||||
key="train_loss",
|
key="train_loss",
|
||||||
|
|
@ -141,7 +145,11 @@ class Model(pl.LightningModule):
|
||||||
self.log("val_metric", metric_val.item())
|
self.log("val_metric", metric_val.item())
|
||||||
self.log("val_loss", loss_val.item())
|
self.log("val_loss", loss_val.item())
|
||||||
|
|
||||||
if self.logger:
|
if (
|
||||||
|
(self.logger)
|
||||||
|
and (self.global_step > 50)
|
||||||
|
and (self.global_step % 50 == 0)
|
||||||
|
):
|
||||||
self.logger.experiment.log_metric(
|
self.logger.experiment.log_metric(
|
||||||
run_id=self.logger.run_id,
|
run_id=self.logger.run_id,
|
||||||
key="val_loss",
|
key="val_loss",
|
||||||
|
|
@ -209,8 +217,7 @@ class Model(pl.LightningModule):
|
||||||
to True or to a string containing your hugginface.co authentication
|
to True or to a string containing your hugginface.co authentication
|
||||||
token that can be obtained by running `huggingface-cli login`
|
token that can be obtained by running `huggingface-cli login`
|
||||||
cache_dir: Path or str, optional
|
cache_dir: Path or str, optional
|
||||||
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
|
Path to model cache directory
|
||||||
environment variable, or "~/.cache/torch/pyannote" when unset.
|
|
||||||
kwargs: optional
|
kwargs: optional
|
||||||
Any extra keyword args needed to init the model.
|
Any extra keyword args needed to init the model.
|
||||||
Can also be used to override saved hyperparameter values.
|
Can also be used to override saved hyperparameter values.
|
||||||
|
|
@ -290,10 +297,9 @@ class Model(pl.LightningModule):
|
||||||
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
), f"Expected batch with 3 dimensions (batch,channels,samples) got only {batch.ndim}"
|
||||||
batch_predictions = []
|
batch_predictions = []
|
||||||
self.eval().to(self.device)
|
self.eval().to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch_id in range(0, batch.shape[0], batch_size):
|
for batch_id in range(0, batch.shape[0], batch_size):
|
||||||
batch_data = batch[batch_id : batch_id + batch_size, :, :].to(
|
batch_data = batch[batch_id : (batch_id + batch_size), :, :].to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
prediction = self(batch_data)
|
prediction = self(batch_data)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
__version__ = "0.0.1"
|
||||||
|
|
@ -6,9 +6,11 @@ librosa>=0.9.2
|
||||||
mlflow>=1.29.0
|
mlflow>=1.29.0
|
||||||
numpy>=1.23.3
|
numpy>=1.23.3
|
||||||
protobuf>=3.19.6
|
protobuf>=3.19.6
|
||||||
|
pytest-lazy-fixture>=0.6.3
|
||||||
pytorch-lightning>=1.7.7
|
pytorch-lightning>=1.7.7
|
||||||
scikit-learn>=1.1.2
|
scikit-learn>=1.1.2
|
||||||
scipy>=1.9.1
|
scipy>=1.9.1
|
||||||
|
soundfile>=0.11.0
|
||||||
torch>=1.12.1
|
torch>=1.12.1
|
||||||
torchaudio>=0.12.1
|
torchaudio>=0.12.1
|
||||||
tqdm>=4.64.1
|
tqdm>=4.64.1
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
# This file is used to configure your project.
|
||||||
|
# Read more about the various options under:
|
||||||
|
# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
|
||||||
|
|
||||||
|
[metadata]
|
||||||
|
name = enhancer
|
||||||
|
description = Deep learning for speech enhacement
|
||||||
|
author = Shahul Ess
|
||||||
|
author-email = shahules786@gmail.com
|
||||||
|
license = mit
|
||||||
|
long-description = file: README.md
|
||||||
|
long-description-content-type = text/markdown; charset=UTF-8; variant=GFM
|
||||||
|
# Change if running only on Windows, Mac or Linux (comma-separated)
|
||||||
|
platforms = Linux, Mac
|
||||||
|
# Add here all kinds of additional classifiers as defined under
|
||||||
|
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
|
||||||
|
classifiers =
|
||||||
|
Development Status :: 4 - Beta
|
||||||
|
Programming Language :: Python
|
||||||
|
|
||||||
|
[options]
|
||||||
|
zip_safe = False
|
||||||
|
packages = find:
|
||||||
|
include_package_data = True
|
||||||
|
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
|
||||||
|
setup_requires = setuptools
|
||||||
|
# Add here dependencies of your project (semicolon/line-separated), e.g.
|
||||||
|
# install_requires = numpy; scipy
|
||||||
|
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
|
||||||
|
python_requires = >=3.8
|
||||||
|
|
||||||
|
[options.packages.find]
|
||||||
|
where = .
|
||||||
|
exclude =
|
||||||
|
tests
|
||||||
|
|
||||||
|
[options.extras_require]
|
||||||
|
# Add here additional requirements for extra features, to install with:
|
||||||
|
# `pip install fastaudio[PDF]` like:
|
||||||
|
# PDF = ReportLab; RXP
|
||||||
|
# Add here test requirements (semicolon/line-separated)
|
||||||
|
testing =
|
||||||
|
pytest>=7.1.3
|
||||||
|
pytest-cov>=4.0.0
|
||||||
|
dev =
|
||||||
|
pre-commit>=2.20.0
|
||||||
|
black>=22.8.0
|
||||||
|
flake8>=5.0.4
|
||||||
|
cli =
|
||||||
|
hydra-core >=1.1,<=1.2
|
||||||
|
|
||||||
|
|
||||||
|
[options.entry_points]
|
||||||
|
|
||||||
|
console_scripts =
|
||||||
|
enhancer-train=enhancer.cli.train:train
|
||||||
|
|
||||||
|
[test]
|
||||||
|
# py.test options when running `python setup.py test`
|
||||||
|
# addopts = --verbose
|
||||||
|
extras = True
|
||||||
|
|
||||||
|
[tool:pytest]
|
||||||
|
# Options for py.test:
|
||||||
|
# Specify command line options as you would do when invoking py.test directly.
|
||||||
|
# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml
|
||||||
|
# in order to write a coverage file that can be read by Jenkins.
|
||||||
|
addopts =
|
||||||
|
--cov enhancer --cov-report term-missing
|
||||||
|
--verbose
|
||||||
|
norecursedirs =
|
||||||
|
dist
|
||||||
|
build
|
||||||
|
.tox
|
||||||
|
testpaths = tests
|
||||||
|
|
||||||
|
[aliases]
|
||||||
|
dists = bdist_wheel
|
||||||
|
|
||||||
|
[bdist_wheel]
|
||||||
|
# Use this option if your package is pure-python
|
||||||
|
universal = 1
|
||||||
|
|
||||||
|
[build_sphinx]
|
||||||
|
source_dir = doc
|
||||||
|
build_dir = build/sphinx
|
||||||
|
|
||||||
|
[devpi:upload]
|
||||||
|
# Options for the devpi: PyPI server and packaging tool
|
||||||
|
# VCS export must be deactivated since we are using setuptools-scm
|
||||||
|
no-vcs = 1
|
||||||
|
formats = bdist_wheel
|
||||||
|
|
||||||
|
[flake8]
|
||||||
|
# Some sane defaults for the code style checker flake8
|
||||||
|
exclude =
|
||||||
|
.tox
|
||||||
|
build
|
||||||
|
dist
|
||||||
|
.eggs
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pkg_resources import VersionConflict, require
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
with open("README.md") as f:
|
||||||
|
long_description = f.read()
|
||||||
|
|
||||||
|
with open("requirements.txt") as f:
|
||||||
|
requirements = f.read().splitlines()
|
||||||
|
|
||||||
|
try:
|
||||||
|
require("setuptools>=38.3")
|
||||||
|
except VersionConflict:
|
||||||
|
print("Error: version of setuptools is too old (<38.3)!")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
ROOT_DIR = Path(__file__).parent.resolve()
|
||||||
|
# Creating the version file
|
||||||
|
|
||||||
|
with open("version.txt") as f:
|
||||||
|
version = f.read()
|
||||||
|
|
||||||
|
version = version.strip()
|
||||||
|
sha = "Unknown"
|
||||||
|
|
||||||
|
if os.getenv("BUILD_VERSION"):
|
||||||
|
version = os.getenv("BUILD_VERSION")
|
||||||
|
elif sha != "Unknown":
|
||||||
|
version += "+" + sha[:7]
|
||||||
|
print("-- Building version " + version)
|
||||||
|
|
||||||
|
version_path = ROOT_DIR / "enhancer" / "version.py"
|
||||||
|
|
||||||
|
with open(version_path, "w") as f:
|
||||||
|
f.write("__version__ = '{}'\n".format(version))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
setup(
|
||||||
|
name="enhancer",
|
||||||
|
namespace_packages=["enhancer"],
|
||||||
|
version=version,
|
||||||
|
packages=find_packages(),
|
||||||
|
install_requires=requirements,
|
||||||
|
description="Deep learning toolkit for speech enhancement",
|
||||||
|
long_description=long_description,
|
||||||
|
long_description_content_type="text/markdown",
|
||||||
|
author="Shahul Es",
|
||||||
|
author_email="shahules786@gmail.com",
|
||||||
|
url="",
|
||||||
|
classifiers=[
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
|
"Intended Audience :: Science/Research",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Natural Language :: English",
|
||||||
|
"Programming Language :: Python :: 3.7",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Topic :: Scientific/Engineering",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
0.0.1
|
||||||
Loading…
Reference in New Issue