Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in:
commit
4f6ccadf4b
|
|
@ -0,0 +1,49 @@
|
||||||
|
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||||
|
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||||
|
|
||||||
|
name: Enhancer
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ dev ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ dev ]
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: [3.8]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v1.1.1
|
||||||
|
env :
|
||||||
|
ACTIONS_ALLOW_UNSECURE_COMMANDS : true
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Cache pip
|
||||||
|
uses: actions/cache@v1
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip # This path is specific to Ubuntu
|
||||||
|
# Look to see if there is a cache hit for the corresponding requirements file
|
||||||
|
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-pip-
|
||||||
|
${{ runner.os }}-
|
||||||
|
# You can test your matrix by printing the current Python version
|
||||||
|
- name: Display Python version
|
||||||
|
run: python -c "import sys; print(sys.version)"
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install black flake8 pytest-cov
|
||||||
|
- name: Run black
|
||||||
|
run:
|
||||||
|
black --check .
|
||||||
|
- name: Run flake8
|
||||||
|
run: flake8
|
||||||
|
- name: Test with pytest
|
||||||
|
run:
|
||||||
|
pytest tests --cov=enhancer/
|
||||||
|
|
@ -98,6 +98,10 @@ class Fileprocessor:
|
||||||
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
|
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
|
||||||
elif name.lower() == "dns-2020":
|
elif name.lower() == "dns-2020":
|
||||||
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
|
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if matching_function not in MATCHING_FNS:
|
if matching_function not in MATCHING_FNS:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ class Inference:
|
||||||
window_size: int,
|
window_size: int,
|
||||||
total_frames: int,
|
total_frames: int,
|
||||||
step_size: Optional[int] = None,
|
step_size: Optional[int] = None,
|
||||||
window="hanning",
|
window="hamming",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
stitch batched waveform into single waveform. (Overlap-add)
|
stitch batched waveform into single waveform. (Overlap-add)
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,7 @@ class Demucs(Model):
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warn(
|
logging.warning(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
|
|
@ -74,9 +74,9 @@ class Model(pl.LightningModule):
|
||||||
def loss(self, loss):
|
def loss(self, loss):
|
||||||
|
|
||||||
if isinstance(loss, str):
|
if isinstance(loss, str):
|
||||||
losses = [loss]
|
loss = [loss]
|
||||||
|
|
||||||
self._loss = Avergeloss(losses)
|
self._loss = Avergeloss(loss)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric(self):
|
def metric(self):
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ class WaveUnet(Model):
|
||||||
)
|
)
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
if sampling_rate != dataset.sampling_rate:
|
if sampling_rate != dataset.sampling_rate:
|
||||||
logging.warn(
|
logging.warning(
|
||||||
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
|
||||||
)
|
)
|
||||||
sampling_rate = dataset.sampling_rate
|
sampling_rate = dataset.sampling_rate
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
black>=22.8.0
|
|
||||||
boto3>=1.24.86
|
boto3>=1.24.86
|
||||||
flake8>=5.0.4
|
|
||||||
huggingface-hub>=0.10.0
|
huggingface-hub>=0.10.0
|
||||||
hydra-core>=1.2.0
|
hydra-core>=1.2.0
|
||||||
joblib>=1.2.0
|
joblib>=1.2.0
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,6 @@ def test_fileprocessor_vctk():
|
||||||
"vctk",
|
"vctk",
|
||||||
"tests/data/vctk/clean_testset_wav",
|
"tests/data/vctk/clean_testset_wav",
|
||||||
"tests/data/vctk/noisy_testset_wav",
|
"tests/data/vctk/noisy_testset_wav",
|
||||||
48000,
|
|
||||||
)
|
)
|
||||||
matching_dict = fp.prepare_matching_dict()
|
matching_dict = fp.prepare_matching_dict()
|
||||||
assert len(matching_dict) == 2
|
assert len(matching_dict) == 2
|
||||||
|
|
@ -39,7 +38,7 @@ def test_fileprocessor_vctk():
|
||||||
|
|
||||||
@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
|
@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
|
||||||
def test_fileprocessor_names(dataset_name):
|
def test_fileprocessor_names(dataset_name):
|
||||||
fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
|
fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir")
|
||||||
assert hasattr(fp.matching_function, "__call__")
|
assert hasattr(fp.matching_function, "__call__")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue