diff --git a/cli/train.py b/cli/train.py
deleted file mode 100644
index dee3d2e..0000000
--- a/cli/train.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from genericpath import isfile
-import os
-from types import MethodType
-import hydra
-from hydra.utils import instantiate
-from omegaconf import DictConfig
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from pytorch_lightning.loggers import MLFlowLogger
-os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID","0")
-
-@hydra.main(config_path="train_config",config_name="config")
-def main(config: DictConfig):
-
-    callbacks = []
-    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                          run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
-
-
-    parameters = config.hyperparameters
-
-    dataset = instantiate(config.dataset)
-    model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
-                        loss=parameters.get("loss"), metric = parameters.get("metric"))
-
-    direction = model.valid_monitor
-    checkpoint = ModelCheckpoint(
-        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
-        mode=direction,every_n_epochs=1
-    )
-    callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-        monitor="val_loss",
-        mode=direction,
-        min_delta=0.0,
-        patience=parameters.get("EarlyStopping_patience",10),
-        strict=True,
-        verbose=False,
-    )
-    callbacks.append(early_stopping)
-
-    def configure_optimizer(self):
-        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
-        scheduler = ReduceLROnPlateau(
-            optimizer=optimizer,
-            mode=direction,
-            factor=parameters.get("ReduceLr_factor",0.1),
-            verbose=True,
-            min_lr=parameters.get("min_lr",1e-6),
-            patience=parameters.get("ReduceLr_patience",3)
-        )
-        return {"optimizer":optimizer, "lr_scheduler":scheduler}
-
-    model.configure_parameters = MethodType(configure_optimizer,model)
-
-    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
-    trainer.fit(model)
-
-    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
-    if os.path.isfile(saved_location):
-        logger.experiment.log_artifact(logger.run_id,saved_location)
-
-
-
-if __name__=="__main__":
-    main()
\ No newline at end of file
diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
new file mode 100644
index 0000000..814fa0f
--- /dev/null
+++ b/enhancer/cli/train.py
@@ -0,0 +1,84 @@
+import os
+from types import MethodType
+import hydra
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.loggers import MLFlowLogger
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID", "0")
+
+
+@hydra.main(config_path="train_config", config_name="config")
+def main(config: DictConfig):
+
+    callbacks = []
+    logger = MLFlowLogger(
+        experiment_name=config.mlflow.experiment_name,
+        run_name=config.mlflow.run_name,
+        tags={"JOB_ID": JOB_ID},
+    )
+
+    parameters = config.hyperparameters
+
+    dataset = instantiate(config.dataset)
+    model = instantiate(
+        config.model,
+        dataset=dataset,
+        lr=parameters.get("lr"),
+        loss=parameters.get("loss"),
+        metric=parameters.get("metric"),
+    )
+
+    direction = model.valid_monitor
+    checkpoint = ModelCheckpoint(
+        dirpath="./model",
+        filename=f"model_{JOB_ID}",
+        monitor="val_loss",
+        verbose=True,
+        mode=direction,
+        every_n_epochs=1,
+    )
+    callbacks.append(checkpoint)
+    early_stopping = EarlyStopping(
+        monitor="val_loss",
+        mode=direction,
+        min_delta=0.0,
+        patience=parameters.get("EarlyStopping_patience", 10),
+        strict=True,
+        verbose=False,
+    )
+    callbacks.append(early_stopping)
+
+    def configure_optimizer(self):
+        optimizer = instantiate(
+            config.optimizer,
+            lr=parameters.get("lr"),
+            parameters=self.parameters(),
+        )
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=parameters.get("ReduceLr_factor", 0.1),
+            verbose=True,
+            min_lr=parameters.get("min_lr", 1e-6),
+            patience=parameters.get("ReduceLr_patience", 3),
+        )
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+    model.configure_optimizers = MethodType(configure_optimizer, model)
+
+    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
+    trainer.fit(model)
+
+    saved_location = os.path.join(
+        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
+    )
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id, saved_location)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/cli/train_config/config.yaml b/enhancer/cli/train_config/config.yaml
similarity index 100%
rename from cli/train_config/config.yaml
rename to enhancer/cli/train_config/config.yaml
diff --git a/cli/train_config/dataset/DNS-2020.yaml b/enhancer/cli/train_config/dataset/DNS-2020.yaml
similarity index 100%
rename from cli/train_config/dataset/DNS-2020.yaml
rename to enhancer/cli/train_config/dataset/DNS-2020.yaml
diff --git a/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml
similarity index 100%
rename from cli/train_config/dataset/Vctk.yaml
rename to enhancer/cli/train_config/dataset/Vctk.yaml
diff --git a/enhancer/cli/train_config/dataset/Vctk_local.yaml b/enhancer/cli/train_config/dataset/Vctk_local.yaml
new file mode 100644
index 0000000..b792b71
--- /dev/null
+++ b/enhancer/cli/train_config/dataset/Vctk_local.yaml
@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+name : vctk
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
+duration : 1.0
+sampling_rate: 16000
+batch_size: 64
+num_workers : 0
+
+files:
+  train_clean : clean_testset_wav
+  test_clean : clean_testset_wav
+  train_noisy : noisy_testset_wav
+  test_noisy : noisy_testset_wav
\ No newline at end of file
diff --git a/cli/train_config/hyperparameters/default.yaml b/enhancer/cli/train_config/hyperparameters/default.yaml
similarity index 100%
rename from cli/train_config/hyperparameters/default.yaml
rename to enhancer/cli/train_config/hyperparameters/default.yaml
diff --git a/cli/train_config/mlflow/experiment.yaml b/enhancer/cli/train_config/mlflow/experiment.yaml
similarity index 100%
rename from cli/train_config/mlflow/experiment.yaml
rename to enhancer/cli/train_config/mlflow/experiment.yaml
diff --git a/cli/train_config/model/Demucs.yaml b/enhancer/cli/train_config/model/Demucs.yaml
similarity index 100%
rename from cli/train_config/model/Demucs.yaml
rename to enhancer/cli/train_config/model/Demucs.yaml
diff --git a/cli/train_config/model/WaveUnet.yaml b/enhancer/cli/train_config/model/WaveUnet.yaml
similarity index 100%
rename from cli/train_config/model/WaveUnet.yaml
rename to enhancer/cli/train_config/model/WaveUnet.yaml
diff --git a/cli/train_config/optimizer/Adam.yaml b/enhancer/cli/train_config/optimizer/Adam.yaml
similarity index 100%
rename from cli/train_config/optimizer/Adam.yaml
rename to enhancer/cli/train_config/optimizer/Adam.yaml
diff --git a/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
similarity index 100%
rename from cli/train_config/trainer/default.yaml
rename to enhancer/cli/train_config/trainer/default.yaml
diff --git a/cli/train_config/trainer/fastrun_dev.yaml b/enhancer/cli/train_config/trainer/fastrun_dev.yaml
similarity index 100%
rename from cli/train_config/trainer/fastrun_dev.yaml
rename to enhancer/cli/train_config/trainer/fastrun_dev.yaml
diff --git a/tests/loss_function_test.py b/tests/loss_function_test.py
index fbc982c..a4fdc62 100644
--- a/tests/loss_function_test.py
+++ b/tests/loss_function_test.py
@@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error
 loss_functions = [mean_absolute_error(), mean_squared_error()]
 
+
 def check_loss_shapes_compatibility(loss_fun):
     batch_size = 4
-    shape = (1,1000)
-    loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape))
+    shape = (1, 1000)
+    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))
     with pytest.raises(TypeError):
-        loss_fun(torch.rand(4,*shape),torch.rand(6,*shape))
+        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))
 
 
-@pytest.mark.parametrize("loss",loss_functions)
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_input_shapes(loss):
     check_loss_shapes_compatibility(loss)
 
-@pytest.mark.parametrize("loss",loss_functions)
+
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_output_type(loss):
     batch_size = 4
-    prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
+    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+        batch_size, 1, 1000
+    )
     loss_value = loss(prediction, target)
-    assert isinstance(loss_value.item(),float)
-
-
+    assert isinstance(loss_value.item(), float)
diff --git a/tests/models/demucs_test.py b/tests/models/demucs_test.py
index a59fa04..6660888 100644
--- a/tests/models/demucs_test.py
+++ b/tests/models/demucs_test.py
@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-        test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset
 
-
-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = Demucs()
     model.eval()
-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)
 
-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)
 
 
-@pytest.mark.parametrize("dataset,channels,loss",
-    [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = Demucs(num_channels=channels,dataset=dataset,loss=loss)
-
-
-
-
-
-
-
+        model = Demucs(num_channels=channels, dataset=dataset, loss=loss)
diff --git a/tests/models/test_waveunet.py b/tests/models/test_waveunet.py
index 43fd14d..c83966b 100644
--- a/tests/models/test_waveunet.py
+++ b/tests/models/test_waveunet.py
@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-        test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset
 
-
-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = WaveUnet()
     model.eval()
-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)
 
-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)
 
 
-@pytest.mark.parametrize("dataset,channels,loss",
-    [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss)
-
-
-
-
-
-
-
+        model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 5eb7442..a6e2423 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -4,22 +4,26 @@ import torch
 from enhancer.inference import Inference
 
 
-@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
+@pytest.mark.parametrize(
+    "audio",
+    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
 def test_read_input(audio):
-    read_audio = Inference.read_input(audio,48000,16000)
-    assert isinstance(read_audio,torch.Tensor)
+    read_audio = Inference.read_input(audio, 48000, 16000)
+    assert isinstance(read_audio, torch.Tensor)
     assert read_audio.shape[0] == 1
 
+
 def test_batchify():
-    rand = torch.rand(1,1000)
-    batched_rand = Inference.batchify(rand, window_size = 100, step_size=100)
+    rand = torch.rand(1, 1000)
+    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
     assert batched_rand.shape[0] == 12
 
+
 def test_aggregate():
-    rand = torch.rand(12,1,100)
-    agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
+    rand = torch.rand(12, 1, 100)
+    agg_rand = Inference.aggreagate(
+        data=rand, window_size=100, total_frames=1000, step_size=100
+    )
     assert agg_rand.shape[-1] == 1000
-
-
-
\ No newline at end of file
diff --git a/tests/utils_test.py b/tests/utils_test.py
index 413bfac..93a9094 100644
--- a/tests/utils_test.py
+++ b/tests/utils_test.py
@@ -7,40 +7,46 @@ from enhancer.utils.io import Audio
 from enhancer.utils.config import Files
 from enhancer.data.fileprocessor import Fileprocessor
 
+
 def test_io_channel():
-    input_audio = np.random.rand(2,32000)
-    audio = Audio(mono=True,return_tensor=False)
+    input_audio = np.random.rand(2, 32000)
+    audio = Audio(mono=True, return_tensor=False)
     output_audio = audio(input_audio)
     assert output_audio.shape[0] == 1
 
+
 def test_io_resampling():
-    input_audio = np.random.rand(1,32000)
-    resampled_audio = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = np.random.rand(1, 32000)
+    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)
 
-    input_audio = torch.rand(1,32000)
-    resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = torch.rand(1, 32000)
+    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)
 
     assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
 
+
 def test_fileprocessor_vctk():
-    fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
-        "tests/data/vctk/noisy_testset_wav",48000)
+    fp = Fileprocessor.from_name(
+        "vctk",
+        "tests/data/vctk/clean_testset_wav",
+        "tests/data/vctk/noisy_testset_wav",
+        48000,
+    )
     matching_dict = fp.prepare_matching_dict()
-    assert len(matching_dict)==2
+    assert len(matching_dict) == 2
 
-@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"])
+
+@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
-    fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000)
-    assert hasattr(fp.matching_function, '__call__')
+    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
+    assert hasattr(fp.matching_function, "__call__")
 
+
 def test_fileprocessor_invaliname():
     with pytest.raises(ValueError):
-        fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
-
-
-
-
-
+        fp = Fileprocessor.from_name(
+            "undefined", "clean_dir", "noisy_dir", 16000
+        ).prepare_matching_dict()