Merge pull request #13 from shahules786/dev-reformat
flake8 cli/ tests/
This commit is contained in: 1b30d7f3ed

cli/train.py (67 lines removed, 84 lines added):
@@ -1,67 +0,0 @@
-from genericpath import isfile
-import os
-from types import MethodType
-import hydra
-from hydra.utils import instantiate
-from omegaconf import DictConfig
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from pytorch_lightning.loggers import MLFlowLogger
-os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID","0")
-
-@hydra.main(config_path="train_config",config_name="config")
-def main(config: DictConfig):
-
-    callbacks = []
-    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                            run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
-
-
-    parameters = config.hyperparameters
-
-    dataset = instantiate(config.dataset)
-    model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
-            loss=parameters.get("loss"), metric = parameters.get("metric"))
-
-    direction = model.valid_monitor
-    checkpoint = ModelCheckpoint(
-        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
-        mode=direction,every_n_epochs=1
-    )
-    callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-            monitor="val_loss",
-            mode=direction,
-            min_delta=0.0,
-            patience=parameters.get("EarlyStopping_patience",10),
-            strict=True,
-            verbose=False,
-        )
-    callbacks.append(early_stopping)
-
-    def configure_optimizer(self):
-        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
-        scheduler = ReduceLROnPlateau(
-            optimizer=optimizer,
-            mode=direction,
-            factor=parameters.get("ReduceLr_factor",0.1),
-            verbose=True,
-            min_lr=parameters.get("min_lr",1e-6),
-            patience=parameters.get("ReduceLr_patience",3)
-        )
-        return {"optimizer":optimizer, "lr_scheduler":scheduler}
-
-    model.configure_parameters = MethodType(configure_optimizer,model)
-
-    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
-    trainer.fit(model)
-
-    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
-    if os.path.isfile(saved_location):
-        logger.experiment.log_artifact(logger.run_id,saved_location)
-
-
-
-if __name__=="__main__":
-    main()
@@ -0,0 +1,84 @@
+import os
+from types import MethodType
+import hydra
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.loggers import MLFlowLogger
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID", "0")
+
+
+@hydra.main(config_path="train_config", config_name="config")
+def main(config: DictConfig):
+
+    callbacks = []
+    logger = MLFlowLogger(
+        experiment_name=config.mlflow.experiment_name,
+        run_name=config.mlflow.run_name,
+        tags={"JOB_ID": JOB_ID},
+    )
+
+    parameters = config.hyperparameters
+
+    dataset = instantiate(config.dataset)
+    model = instantiate(
+        config.model,
+        dataset=dataset,
+        lr=parameters.get("lr"),
+        loss=parameters.get("loss"),
+        metric=parameters.get("metric"),
+    )
+
+    direction = model.valid_monitor
+    checkpoint = ModelCheckpoint(
+        dirpath="./model",
+        filename=f"model_{JOB_ID}",
+        monitor="val_loss",
+        verbose=True,
+        mode=direction,
+        every_n_epochs=1,
+    )
+    callbacks.append(checkpoint)
+    early_stopping = EarlyStopping(
+        monitor="val_loss",
+        mode=direction,
+        min_delta=0.0,
+        patience=parameters.get("EarlyStopping_patience", 10),
+        strict=True,
+        verbose=False,
+    )
+    callbacks.append(early_stopping)
+
+    def configure_optimizer(self):
+        optimizer = instantiate(
+            config.optimizer,
+            lr=parameters.get("lr"),
+            parameters=self.parameters(),
+        )
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=parameters.get("ReduceLr_factor", 0.1),
+            verbose=True,
+            min_lr=parameters.get("min_lr", 1e-6),
+            patience=parameters.get("ReduceLr_patience", 3),
+        )
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+    model.configure_parameters = MethodType(configure_optimizer, model)
+
+    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
+    trainer.fit(model)
+
+    saved_location = os.path.join(
+        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
+    )
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id, saved_location)
+
+
+if __name__ == "__main__":
+    main()
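Side note on the pattern both versions of train.py share: configure_optimizer is built as a closure over the Hydra config and then bound onto the model with types.MethodType. (PyTorch Lightning's hook is spelled configure_optimizers, so the configure_parameters attribute assigned here may never be invoked by the trainer.) A minimal, self-contained sketch of the binding idiom itself, using a toy class rather than the enhancer model:

from types import MethodType


class Model:
    pass


def configure_optimizer(self):
    # "self" is the instance the function was bound to, as in a normal method
    return {"optimizer": type(self).__name__}


model = Model()
model.configure_optimizer = MethodType(configure_optimizer, model)
assert model.configure_optimizer() == {"optimizer": "Model"}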
New file (VCTK dataset config; path not shown in this view):

@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+name : vctk
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
+duration : 1.0
+sampling_rate: 16000
+batch_size: 64
+num_workers : 0
+
+files:
+  train_clean : clean_testset_wav
+  test_clean : clean_testset_wav
+  train_noisy : noisy_testset_wav
+  test_noisy : noisy_testset_wav
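For context on how this config is consumed: hydra.utils.instantiate imports the dotted path in _target_ and calls it with the remaining keys as keyword arguments, which is how instantiate(config.dataset) in train.py builds the EnhancerDataset. A tiny sketch of that behavior with a standard-library target, so it runs without the enhancer package:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# collections.Counter stands in for enhancer.data.dataset.EnhancerDataset
cfg = OmegaConf.create({"_target_": "collections.Counter", "vctk": 64})
obj = instantiate(cfg)  # imports collections.Counter, calls Counter(vctk=64)
assert obj["vctk"] == 64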
Tests (file paths not shown in this view):

@@ -6,6 +6,7 @@ from enhancer.loss import mean_absolute_error, mean_squared_error
 
 loss_functions = [mean_absolute_error(), mean_squared_error()]
 
+
 def check_loss_shapes_compatibility(loss_fun):
 
     batch_size = 4
@@ -20,12 +21,13 @@ def check_loss_shapes_compatibility(loss_fun):
 def test_loss_input_shapes(loss):
     check_loss_shapes_compatibility(loss)
 
+
 @pytest.mark.parametrize("loss", loss_functions)
 def test_loss_output_type(loss):
 
     batch_size = 4
-    prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
+    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+        batch_size, 1, 1000
+    )
     loss_value = loss(prediction, target)
     assert isinstance(loss_value.item(), float)
-
-
@@ -10,13 +10,16 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-         test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
     dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset
 
 
-
 @pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
 def test_forward(batch_size, samples):
     model = Demucs()
@@ -32,15 +35,10 @@ def test_forward(batch_size,samples):
             _ = model(data)
 
 
-@pytest.mark.parametrize("dataset,channels,loss",
-                        [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
 def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
         model = Demucs(num_channels=channels, dataset=dataset, loss=loss)
-
-
-
-
-
-    
-
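The parametrize change above depends on pytest.lazy_fixture, which comes from the pytest-lazy-fixture plugin and defers fixture resolution so a fixture can appear as a parametrize value. A stripped-down sketch of the idiom, with a stub fixture in place of the real dataset:

import pytest


@pytest.fixture
def vctk_dataset():
    return "dataset-stub"  # stand-in for the real EnhancerDataset fixture


@pytest.mark.parametrize(
    "dataset,channels", [(pytest.lazy_fixture("vctk_dataset"), 1)]
)
def test_init(dataset, channels):
    assert dataset == "dataset-stub" and channels == 1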
@@ -10,13 +10,16 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-         test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
     dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset
 
 
-
 @pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
 def test_forward(batch_size, samples):
     model = WaveUnet()
@@ -32,15 +35,10 @@ def test_forward(batch_size,samples):
             _ = model(data)
 
 
-@pytest.mark.parametrize("dataset,channels,loss",
-                        [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
 def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
         model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
-
-
-
-
-
-    
-
@@ -4,22 +4,26 @@ import torch
 from enhancer.inference import Inference
 
 
-@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
+@pytest.mark.parametrize(
+    "audio",
+    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
 def test_read_input(audio):
 
     read_audio = Inference.read_input(audio, 48000, 16000)
     assert isinstance(read_audio, torch.Tensor)
     assert read_audio.shape[0] == 1
 
+
 def test_batchify():
     rand = torch.rand(1, 1000)
     batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
     assert batched_rand.shape[0] == 12
 
+
 def test_aggregate():
     rand = torch.rand(12, 1, 100)
-    agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
+    agg_rand = Inference.aggreagate(
+        data=rand, window_size=100, total_frames=1000, step_size=100
+    )
     assert agg_rand.shape[-1] == 1000
-
-
-    
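The two tests above describe a sliding-window round trip: batchify frames a (1, 1000)-sample signal into windows of 100 with step 100 (the expected count of 12 rather than 10 suggests the implementation also pads the edges), and aggreagate stitches windows back to the original 1000 frames. A rough, padding-free sketch of the framing idea with torch.Tensor.unfold, not the library's implementation:

import torch

signal = torch.rand(1, 1000)
# Non-overlapping windows: step equals window size, no padding
windows = signal.unfold(dimension=-1, size=100, step=100)  # shape (1, 10, 100)

# With no overlap and no padding, aggregation is just the inverse reshape
restored = windows.reshape(1, -1)
assert torch.equal(restored, signal)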
@@ -7,6 +7,7 @@ from enhancer.utils.io import Audio
 from enhancer.utils.config import Files
 from enhancer.data.fileprocessor import Fileprocessor
 
+
 def test_io_channel():
 
     input_audio = np.random.rand(2, 32000)
@@ -14,6 +15,7 @@ def test_io_channel():
     output_audio = audio(input_audio)
     assert output_audio.shape[0] == 1
 
+
 def test_io_resampling():
 
     input_audio = np.random.rand(1, 32000)
@@ -24,23 +26,27 @@ def test_io_resampling():
 
     assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
 
+
 def test_fileprocessor_vctk():
 
-    fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
-                                "tests/data/vctk/noisy_testset_wav",48000)
+    fp = Fileprocessor.from_name(
+        "vctk",
+        "tests/data/vctk/clean_testset_wav",
+        "tests/data/vctk/noisy_testset_wav",
+        48000,
+    )
     matching_dict = fp.prepare_matching_dict()
     assert len(matching_dict) == 2
 
+
 @pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
     fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
-     assert hasattr(fp.matching_function, '__call__')
+    assert hasattr(fp.matching_function, "__call__")
 
 
 def test_fileprocessor_invaliname():
     with pytest.raises(ValueError):
-        fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
-
-
-
-
+        fp = Fileprocessor.from_name(
+            "undefined", "clean_dir", "noisy_dir", 16000
+        ).prepare_matching_dict()