Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
commit 4f6ccadf4b
@@ -0,0 +1,49 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Enhancer
+
+on:
+  push:
+    branches: [ dev ]
+  pull_request:
+    branches: [ dev ]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.8]
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v1.1.1
+        env:
+          ACTIONS_ALLOW_UNSECURE_COMMANDS: true
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache pip
+        uses: actions/cache@v1
+        with:
+          path: ~/.cache/pip # This path is specific to Ubuntu
+          # Look to see if there is a cache hit for the corresponding requirements file
+          key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-
+            ${{ runner.os }}-
+      # You can test your matrix by printing the current Python version
+      - name: Display Python version
+        run: python -c "import sys; print(sys.version)"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install black flake8 pytest-cov
+      - name: Run black
+        run:
+          black --check .
+      - name: Run flake8
+        run: flake8
+      - name: Test with pytest
+        run:
+          pytest tests --cov=enhancer/
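The new workflow runs formatting, lint, and test steps in sequence on every push and pull request to dev. The same checks can be reproduced locally before pushing; the sketch below is a minimal helper script, not part of the repository, and assumes requirements.txt at the repo root, tests under tests/, and black/flake8/pytest-cov installed in the active environment.

# run_checks.py -- hypothetical local mirror of the CI steps above.
import subprocess
import sys

STEPS = [
    [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
    [sys.executable, "-m", "black", "--check", "."],
    [sys.executable, "-m", "flake8"],
    [sys.executable, "-m", "pytest", "tests", "--cov=enhancer/"],
]

for cmd in STEPS:
    print("running:", " ".join(cmd))
    result = subprocess.run(cmd)
    if result.returncode != 0:
        sys.exit(result.returncode)  # fail fast, like the CI job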
@@ -98,6 +98,10 @@ class Fileprocessor:
                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
             elif name.lower() == "dns-2020":
                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
+            else:
+                raise ValueError(
+                    f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}"
+                )
         else:
             if matching_function not in MATCHING_FNS:
                 raise ValueError(
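The added else branch makes Fileprocessor.from_name fail loudly on unknown dataset names instead of falling through silently. A simplified sketch of the same name-based dispatch pattern (the dictionary, function name, and string values here are illustrative stand-ins, not the repository's ProcessorFunctions):

# Hypothetical, simplified dispatch with an explicit error branch.
NAME_TO_FN = {
    "vctk": "one_to_one",
    "dns-2020": "one_to_many",
}

def matching_function_for(name: str) -> str:
    try:
        return NAME_TO_FN[name.lower()]
    except KeyError:
        raise ValueError(
            f"Invalid dataset name {name!r}, expected one of {sorted(NAME_TO_FN)}"
        ) from None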
@@ -91,7 +91,7 @@ class Inference:
         window_size: int,
         total_frames: int,
         step_size: Optional[int] = None,
-        window="hanning",
+        window="hamming",
     ):
         """
         stitch batched waveform into single waveform. (Overlap-add)
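The stitching helper switches its default analysis window from Hann to Hamming; both are available in NumPy. A minimal overlap-add sketch, assuming equal-length frames and a hop of step_size samples (this is an illustration, not the repository's implementation):

import numpy as np

def overlap_add(frames: np.ndarray, step_size: int, window: str = "hamming") -> np.ndarray:
    """Stitch (num_frames, frame_len) back into a single waveform by overlap-add."""
    num_frames, frame_len = frames.shape
    win = np.hamming(frame_len) if window == "hamming" else np.hanning(frame_len)
    out = np.zeros((num_frames - 1) * step_size + frame_len)
    norm = np.zeros_like(out)  # accumulated window weight for normalization
    for i, frame in enumerate(frames):
        start = i * step_size
        out[start:start + frame_len] += frame * win
        norm[start:start + frame_len] += win
    return out / np.maximum(norm, 1e-8)  # avoid division by zero at the edges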
@@ -143,7 +143,7 @@ class Demucs(Model):
             )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                logging.warn(
+                logging.warning(
                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                 )
                 sampling_rate = dataset.sampling_rate
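logging.warn is a long-deprecated alias of logging.warning (calling it emits a DeprecationWarning on recent Python versions), so the rename here and in WaveUnet below is the safe spelling. For example:

import logging

logging.basicConfig(level=logging.INFO)

# Preferred spelling: logging.warning (illustrative sample-rate values).
logging.warning(
    "model sampling rate %s should match dataset sampling rate %s", 16000, 48000
)
# Deprecated alias: logging.warn -- still works today, but it triggers a
# DeprecationWarning and may be removed in a future Python release.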
@@ -74,9 +74,9 @@ class Model(pl.LightningModule):
     def loss(self, loss):
 
         if isinstance(loss, str):
-            losses = [loss]
+            loss = [loss]
 
-        self._loss = Avergeloss(losses)
+        self._loss = Avergeloss(loss)
 
     @property
     def metric(self):
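Before this change, passing a list of loss names raised NameError because losses was only bound inside the str branch; the fix normalizes the argument in place so both a single name and a list work. A minimal sketch of the same normalize-then-wrap pattern (the registry and the averaging lambda are hypothetical stand-ins for the repository's Avergeloss class):

from typing import Callable, Dict, List, Union

def make_average_loss(loss: Union[str, List[str]],
                      registry: Dict[str, Callable]) -> Callable:
    """Normalize a single loss name or a list of names, then wrap them."""
    if isinstance(loss, str):
        loss = [loss]  # normalize in place, mirroring the fixed setter
    fns = [registry[name] for name in loss]
    # Hypothetical stand-in for Avergeloss: average the individual losses.
    return lambda pred, target: sum(fn(pred, target) for fn in fns) / len(fns)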
@@ -107,7 +107,7 @@ class WaveUnet(Model):
             )
         if dataset is not None:
             if sampling_rate != dataset.sampling_rate:
-                logging.warn(
+                logging.warning(
                     f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}"
                 )
                 sampling_rate = dataset.sampling_rate
@@ -1,6 +1,4 @@
-black>=22.8.0
 boto3>=1.24.86
-flake8>=5.0.4
 huggingface-hub>=0.10.0
 hydra-core>=1.2.0
 joblib>=1.2.0
@@ -31,7 +31,6 @@ def test_fileprocessor_vctk():
         "vctk",
         "tests/data/vctk/clean_testset_wav",
         "tests/data/vctk/noisy_testset_wav",
-        48000,
     )
     matching_dict = fp.prepare_matching_dict()
     assert len(matching_dict) == 2
@@ -39,7 +38,7 @@ def test_fileprocessor_vctk():
 
 @pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
-    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
+    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir")
     assert hasattr(fp.matching_function, "__call__")
 
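The updated tests drop the explicit sampling-rate argument to match the new from_name signature. They can be run on their own while iterating; a small sketch using pytest's Python entry point (test directory assumed to be tests/):

import pytest

# Run only the fileprocessor tests; "-k" filters by test-name substring.
# An exit code of 0 means every selected test passed.
exit_code = pytest.main(["-k", "fileprocessor", "tests"])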