diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..6b3c716 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,49 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Enhancer + +on: + push: + branches: [ dev ] + pull_request: + branches: [ dev ] +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1.1.1 + env : + ACTIONS_ALLOW_UNSECURE_COMMANDS : true + with: + python-version: ${{ matrix.python-version }} + - name: Cache pip + uses: actions/cache@v1 + with: + path: ~/.cache/pip # This path is specific to Ubuntu + # Look to see if there is a cache hit for the corresponding requirements file + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip- + ${{ runner.os }}- + # You can test your matrix by printing the current Python version + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install black flake8 pytest-cov + - name: Run black + run: + black --check . + - name: Run flake8 + run: flake8 + - name: Test with pytest + run: + pytest tests --cov=enhancer/ diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index 66d4d75..03afc73 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -98,6 +98,10 @@ class Fileprocessor: return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one) elif name.lower() == "dns-2020": return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many) + else: + raise ValueError( + f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}" + ) else: if matching_function not in MATCHING_FNS: raise ValueError( diff --git a/enhancer/inference.py b/enhancer/inference.py index ae399f1..fd3b518 100644 --- a/enhancer/inference.py +++ b/enhancer/inference.py @@ -91,7 +91,7 @@ class Inference: window_size: int, total_frames: int, step_size: Optional[int] = None, - window="hanning", + window="hamming", ): """ stitch batched waveform into single waveform. (Overlap-add) diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py index 65f119d..bf9d429 100644 --- a/enhancer/models/demucs.py +++ b/enhancer/models/demucs.py @@ -143,7 +143,7 @@ class Demucs(Model): ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warn( + logging.warning( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/enhancer/models/model.py b/enhancer/models/model.py index 39dbe80..6e6b4e1 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -74,9 +74,9 @@ class Model(pl.LightningModule): def loss(self, loss): if isinstance(loss, str): - losses = [loss] + loss = [loss] - self._loss = Avergeloss(losses) + self._loss = Avergeloss(loss) @property def metric(self): diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py index ebb4b1f..ea5646a 100644 --- a/enhancer/models/waveunet.py +++ b/enhancer/models/waveunet.py @@ -107,7 +107,7 @@ class WaveUnet(Model): ) if dataset is not None: if sampling_rate != dataset.sampling_rate: - logging.warn( + logging.warning( f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}" ) sampling_rate = dataset.sampling_rate diff --git a/requirements.txt b/requirements.txt index a16acec..bb13983 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,4 @@ -black>=22.8.0 boto3>=1.24.86 -flake8>=5.0.4 huggingface-hub>=0.10.0 hydra-core>=1.2.0 joblib>=1.2.0 diff --git a/tests/utils_test.py b/tests/utils_test.py index 65c723d..cd5240c 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -31,7 +31,6 @@ def test_fileprocessor_vctk(): "vctk", "tests/data/vctk/clean_testset_wav", "tests/data/vctk/noisy_testset_wav", - 48000, ) matching_dict = fp.prepare_matching_dict() assert len(matching_dict) == 2 @@ -39,7 +38,7 @@ def test_fileprocessor_vctk(): @pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"]) def test_fileprocessor_names(dataset_name): - fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000) + fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir") assert hasattr(fp.matching_function, "__call__")