Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk

This commit is contained in:
shahules786 2022-10-06 11:49:40 +05:30
commit 4f6ccadf4b
8 changed files with 59 additions and 9 deletions

49
.github/workflows/ci.yaml vendored Normal file
View File

@ -0,0 +1,49 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Enhancer
on:
push:
branches: [ dev ]
pull_request:
branches: [ dev ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1.1.1
env :
ACTIONS_ALLOW_UNSECURE_COMMANDS : true
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v1
with:
path: ~/.cache/pip # This path is specific to Ubuntu
# Look to see if there is a cache hit for the corresponding requirements file
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
# You can test your matrix by printing the current Python version
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install black flake8 pytest-cov
- name: Run black
run:
black --check .
- name: Run flake8
run: flake8
- name: Test with pytest
run:
pytest tests --cov=enhancer/

View File

@ -98,6 +98,10 @@ class Fileprocessor:
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "dns-2020":
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
else:
raise ValueError(
f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}"
)
else:
if matching_function not in MATCHING_FNS:
raise ValueError(

View File

@ -91,7 +91,7 @@ class Inference:
window_size: int,
total_frames: int,
step_size: Optional[int] = None,
window="hanning",
window="hamming",
):
"""
stitch batched waveform into single waveform. (Overlap-add)

View File

@ -143,7 +143,7 @@ class Demucs(Model):
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warn(
logging.warning(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

View File

@ -74,9 +74,9 @@ class Model(pl.LightningModule):
def loss(self, loss):
if isinstance(loss, str):
losses = [loss]
loss = [loss]
self._loss = Avergeloss(losses)
self._loss = Avergeloss(loss)
@property
def metric(self):

View File

@ -107,7 +107,7 @@ class WaveUnet(Model):
)
if dataset is not None:
if sampling_rate != dataset.sampling_rate:
logging.warn(
logging.warning(
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
)
sampling_rate = dataset.sampling_rate

View File

@ -1,6 +1,4 @@
black>=22.8.0
boto3>=1.24.86
flake8>=5.0.4
huggingface-hub>=0.10.0
hydra-core>=1.2.0
joblib>=1.2.0

View File

@ -31,7 +31,6 @@ def test_fileprocessor_vctk():
"vctk",
"tests/data/vctk/clean_testset_wav",
"tests/data/vctk/noisy_testset_wav",
48000,
)
matching_dict = fp.prepare_matching_dict()
assert len(matching_dict) == 2
@ -39,7 +38,7 @@ def test_fileprocessor_vctk():
@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name):
fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir")
assert hasattr(fp.matching_function, "__call__")