diff --git a/.gitignore b/.gitignore
index ae420f9..6eb0fe3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
-##local
+#local
+cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/
 datasets/
diff --git a/cli/train.py b/cli/train.py
index 9ed1cd0..88e513a 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,17 +1,17 @@
+import os
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
-from enhancer.data.dataset import EnhancerDataset
 
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
 
     callbacks = []
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                          run_name=config.mlflow.run_name)
+                          run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
 
     parameters = config.hyperparameters
@@ -21,12 +21,12 @@ def main(config: DictConfig):
                 loss=parameters.get("loss"), metric = parameters.get("metric"))
 
     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
+        dirpath="",filename="model",monitor="val_loss",verbose=False,
         mode="min",every_n_epochs=1
     )
     callbacks.append(checkpoint)
     early_stopping = EarlyStopping(
-        monitor=parameters.get("loss"),
+        monitor="val_loss",
         mode="min",
         min_delta=0.0,
         patience=100,
diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 7845b01..61551bd 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -1,7 +1,7 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
-  - trainer : fastrun_dev
+  - trainer : default
   - mlflow : experiment
\ No newline at end of file
diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml
index d40f27f..d1c8646 100644
--- a/cli/train_config/dataset/Vctk.yaml
+++ b/cli/train_config/dataset/Vctk.yaml
@@ -2,8 +2,8 @@ _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 1.0
-sampling_rate: 48000
-batch_size: 32
+sampling_rate: 16000
+batch_size: 8
 
 files:
   train_clean : clean_trainset_56spk_wav
diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml
index 5cbdcb0..04b099b 100644
--- a/cli/train_config/hyperparameters/default.yaml
+++ b/cli/train_config/hyperparameters/default.yaml
@@ -1,4 +1,4 @@
 loss : mse
 metric : mae
-lr : 0.001
-num_epochs : 10
+lr : 0.0001
+num_epochs : 100
diff --git a/cli/train_config/mlflow/experiment.yaml b/cli/train_config/mlflow/experiment.yaml
index b64b125..2995c60 100644
--- a/cli/train_config/mlflow/experiment.yaml
+++ b/cli/train_config/mlflow/experiment.yaml
@@ -1,2 +1,2 @@
-experiment_name : "myexp"
-run_name : "myrun"
\ No newline at end of file
+experiment_name : shahules/enhancer
+run_name : baseline
\ No newline at end of file
diff --git a/cli/train_config/trainer/default.yml b/cli/train_config/trainer/default.yaml
similarity index 90%
rename from cli/train_config/trainer/default.yml
rename to cli/train_config/trainer/default.yaml
index eeb5b85..ab4e273 100644
--- a/cli/train_config/trainer/default.yml
+++ b/cli/train_config/trainer/default.yaml
@@ -1,16 +1,15 @@
-# @package _group_
 _target_: pytorch_lightning.Trainer
 accelerator: auto
 accumulate_grad_batches: 1
 amp_backend: native
-auto_lr_find: False
+auto_lr_find: True
 auto_scale_batch_size: False
 auto_select_gpus: True
 benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: auto
+devices: -1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
@@ -23,8 +22,8 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 50
-max_epochs: 1000
+log_every_n_steps: 10
+max_epochs: 100
 max_steps: null
 max_time: null
 min_epochs: 1
diff --git a/cli/train_config/trainer/fastrun_dev.yaml b/cli/train_config/trainer/fastrun_dev.yaml
index 5d0895f..682149e 100644
--- a/cli/train_config/trainer/fastrun_dev.yaml
+++ b/cli/train_config/trainer/fastrun_dev.yaml
@@ -1,3 +1,2 @@
-# @package _group_
 _target_: pytorch_lightning.Trainer
 fast_dev_run: True
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index f4e7e4a..5749c36 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -1,6 +1,4 @@
-
-from dataclasses import dataclass
-import glob
+import multiprocessing
 import math
 import os
 import pytorch_lightning as pl
@@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
                  duration:float=1.0,
                  sampling_rate:int=48000,
                  matching_function = None,
-                 batch_size=32):
+                 batch_size=32,
+                 num_workers:Optional[int]=None):
 
         super().__init__()
         self.name = name
@@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
         self.batch_size = batch_size
         self.matching_function = matching_function
         self._validation = []
+        if num_workers is None:
+            num_workers = multiprocessing.cpu_count()//2
+        self.num_workers = num_workers
 
     def setup(self, stage: Optional[str] = None):
 
@@ -64,15 +66,13 @@ class TaskDataset(pl.LightningDataModule):
            train_clean = os.path.join(self.root_dir,self.files.train_clean)
            train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
            fp = Fileprocessor.from_name(self.name,train_clean,
-                                        train_noisy,self.sampling_rate,
-                                        self.matching_function)
+                                        train_noisy, self.matching_function)
            self.train_data = fp.prepare_matching_dict()
 
            val_clean = os.path.join(self.root_dir,self.files.test_clean)
            val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
            fp = Fileprocessor.from_name(self.name,val_clean,
-                                        val_noisy,self.sampling_rate,
-                                        self.matching_function)
+                                        val_noisy, self.matching_function)
            val_data = fp.prepare_matching_dict()
 
            for item in val_data:
@@ -85,10 +85,10 @@
                self._validation.append(({"clean":clean,"noisy":noisy}, start_time))
 
     def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
     def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
 
 class EnhancerDataset(TaskDataset):
     """Dataset object for creating clean-noisy speech enhancement datasets"""
@@ -101,7 +101,8 @@ class EnhancerDataset(TaskDataset):
                  duration=1.0,
                  sampling_rate=48000,
                  matching_function=None,
-                 batch_size=32):
+                 batch_size=32,
+                 num_workers:Optional[int]=None):
 
         super().__init__(
             name=name,
@@ -110,7 +111,8 @@ class EnhancerDataset(TaskDataset):
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function = matching_function,
-            batch_size=batch_size
+            batch_size=batch_size,
+            num_workers = num_workers,
         )
 
 
diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
index 4df3e23..f903375 100644
--- a/enhancer/data/fileprocessor.py
+++ b/enhancer/data/fileprocessor.py
@@ -1,12 +1,12 @@
 import glob
 import os
 import numpy as np
 from scipy.io import wavfile
 
 class ProcessorFunctions:
 
     @staticmethod
-    def match_vtck(clean_path,noisy_path,sr):
+    def match_vtck(clean_path,noisy_path):
 
         matching_wavfiles = list()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -18,16 +18,15 @@ class ProcessorFunctions:
             sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
             sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
             if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                 matching_wavfiles.append(
                     {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                 )
         return matching_wavfiles
 
     @staticmethod
-    def match_dns2020(clean_path,noisy_path,sr):
+    def match_dns2020(clean_path,noisy_path):
 
         matching_wavfiles = dict()
         clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -38,11 +37,10 @@ class ProcessorFunctions:
             sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
             sr_noisy, noisy_file = wavfile.read(noisy_file)
             if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                 matching_wavfiles.update(
                     {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                 )
 
         return matching_wavfiles
@@ -54,12 +52,10 @@ class Fileprocessor:
         self,
         clean_dir,
         noisy_dir,
-        sr = 16000,
         matching_function = None
     ):
         self.clean_dir = clean_dir
         self.noisy_dir = noisy_dir
-        self.sr = sr
         self.matching_function = matching_function
 
     @classmethod
@@ -67,23 +63,22 @@ class Fileprocessor:
         name:str,
         clean_dir,
         noisy_dir,
-        sr,
         matching_function=None
     ):
 
         if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_vtck)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
 
         elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_dns2020)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
 
         else:
-            return cls(clean_dir,noisy_dir,sr, matching_function)
+            return cls(clean_dir,noisy_dir, matching_function)
 
     def prepare_matching_dict(self):
 
         if self.matching_function is None:
             raise ValueError("Not a valid matching function")
 
-        return self.matching_function(self.clean_dir,self.noisy_dir,self.sr)
+        return self.matching_function(self.clean_dir,self.noisy_dir)
diff --git a/enhancer/inference.py b/enhancer/inference.py
index 404fe95..6e9cff7 100644
--- a/enhancer/inference.py
+++ b/enhancer/inference.py
@@ -10,7 +10,6 @@ from pathlib import Path
 
 from librosa import load as load_audio
 
 from enhancer.utils import Audio
-from enhancer.utils.config import DEFAULT_DEVICE
 
 class Inference:
diff --git a/enhancer/models/demucs.py b/enhancer/models/demucs.py
index 115f63e..7c9d8ff 100644
--- a/enhancer/models/demucs.py
+++ b/enhancer/models/demucs.py
@@ -1,9 +1,8 @@
-from base64 import encode
-from turtle import forward
+import logging
 from typing import Optional, Union, List
 from torch import nn
 import torch.nn.functional as F
-import math
+import math
 
 from enhancer.models.model import Model
 from enhancer.data.dataset import EnhancerDataset
@@ -114,6 +113,10 @@
 
     ):
         duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warning(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+                sampling_rate = dataset.sampling_rate
         super().__init__(num_channels=num_channels,
                         sampling_rate=sampling_rate,lr=lr,
                         dataset=dataset,duration=duration,loss=loss, metric=metric)
diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index c4be077..8e607ed 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -1,5 +1,6 @@
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
+import logging
 import numpy as np
 import os
 from typing import Optional, Union, List, Text, Dict, Any
@@ -82,8 +83,10 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)
 
         if self.logger:
-            self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
-
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="train_loss", value=loss.item(),
+                                              step=self.global_step)
+        self.log("train_loss",loss.item())
         return {"loss":loss}
 
     def validation_step(self,batch,batch_idx:int):
@@ -92,11 +95,20 @@ class Model(pl.LightningModule):
         target = batch["clean"]
         prediction = self(mixed_waveform)
 
-        loss = self.metric(prediction, target)
-        if self.logger:
-            self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
+        metric_val = self.metric(prediction, target)
+        loss_val = self.loss(prediction, target)
+        self.log("val_metric",metric_val.item())
+        self.log("val_loss",loss_val.item())
 
-        return {"loss":loss}
+        if self.logger:
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_loss",value=loss_val.item(),
+                                              step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_metric",value=metric_val.item(),
+                                              step=self.global_step)
+
+        return {"loss":loss_val}
 
     def on_save_checkpoint(self, checkpoint):
 
diff --git a/enhancer/models/waveunet.py b/enhancer/models/waveunet.py
index 89b4bb7..f799352 100644
--- a/enhancer/models/waveunet.py
+++ b/enhancer/models/waveunet.py
@@ -1,5 +1,4 @@
-from tkinter import wantobjects
-import wave
+import logging
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -70,6 +69,11 @@ class WaveUnet(Model):
         loss: Union[str, List] = "mse",
         metric:Union[str,List] = "mse"
     ):
+        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warning(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+                sampling_rate = dataset.sampling_rate
         super().__init__(num_channels=num_channels,
                         sampling_rate=sampling_rate,lr=lr,
                         dataset=dataset,duration=duration,loss=loss, metric=metric
@@ -125,7 +129,6 @@ class WaveUnet(Model):
 
         for layer,decoder in enumerate(self.decoders):
             out = F.interpolate(out, scale_factor=2, mode="linear")
-            print(out.shape,encoder_outputs[layer].shape)
             out = self.fix_last_dim(out,encoder_outputs[layer])
             out = torch.cat([out,encoder_outputs[layer]],dim=1)
             out = decoder(out)
diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py
index 0aaa2e3..e9af6a0 100644
--- a/enhancer/utils/config.py
+++ b/enhancer/utils/config.py
@@ -1,18 +1,11 @@
 from dataclasses import dataclass
 
-@dataclass
-class Paths:
-    log : str
-    data : str
 
 @dataclass
 class Files:
+    root_dir : str
     train_clean : str
     train_noisy : str
     test_clean : str
     test_noisy : str
-
-@dataclass
-class EnhancerConfig:
-    path : Paths
-    files: Files
\ No newline at end of file
+
diff --git a/enhancer/utils/io.py b/enhancer/utils/io.py
index 80bd1a4..afc19e8 100644
--- a/enhancer/utils/io.py
+++ b/enhancer/utils/io.py
@@ -1,7 +1,6 @@
 import os
 import librosa
 from typing import Optional
-from matplotlib.pyplot import axis
 import numpy as np
 import torch
 import torchaudio
diff --git a/hpc_entrypoint.sh b/hpc_entrypoint.sh
index 2a4e3c7..7372eb9 100644
--- a/hpc_entrypoint.sh
+++ b/hpc_entrypoint.sh
@@ -22,6 +22,8 @@ echo "Activate Environment"
 source activate enhancer
 export TRANSFORMERS_OFFLINE=True
 export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
+export HYDRA_FULL_ERROR=1
+
 echo $PYTHONPATH
 
 source ~/mlflow_settings.sh
diff --git a/requirements.txt b/requirements.txt
index c74e46d..e7fcd24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,5 +10,6 @@ tqdm==4.64.0
 mlflow==1.23.1
 protobuf==3.19.3
 boto3==1.23.9
+torchaudio==0.10.2
 huggingface-hub==0.4.0
 pytorch-lightning==1.5.10
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 303aa67..5eb7442 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -1,6 +1,5 @@
 import pytest
 import torch
-import numpy as np
 
 from enhancer.inference import Inference
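
Note on the num_workers change in enhancer/data/dataset.py above: a minimal sketch of how the new default resolves when the argument is left unset. The standalone helper below is illustrative only, not part of the repo; it just mirrors the logic added to TaskDataset.__init__.

    import multiprocessing
    from typing import Optional

    # Mirrors the default added in TaskDataset.__init__: None resolves to
    # half the available CPU cores, and the resolved value is then handed
    # to both the train and validation DataLoaders via num_workers.
    def resolve_num_workers(num_workers: Optional[int] = None) -> int:
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
        return num_workers

    print(resolve_num_workers())   # e.g. 4 on an 8-core machine
    print(resolve_num_workers(2))  # an explicit value is passed through unchanged

An explicitly supplied num_workers (for example from the Hydra dataset config) is kept as-is; the half-the-cores default only applies when the argument is omitted or None.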