Merge pull request #13 from shahules786/dev-reformat
flake8 cli/ tests/
commit 1b30d7f3ed

cli/train.py
@@ -1,67 +0,0 @@
-from genericpath import isfile
-import os
-from types import MethodType
-import hydra
-from hydra.utils import instantiate
-from omegaconf import DictConfig
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from pytorch_lightning.loggers import MLFlowLogger
-os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID","0")
-
-@hydra.main(config_path="train_config",config_name="config")
-def main(config: DictConfig):
-
-    callbacks = []
-    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                            run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})
-
-
-    parameters = config.hyperparameters
-
-    dataset = instantiate(config.dataset)
-    model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
-            loss=parameters.get("loss"), metric = parameters.get("metric"))
-
-    direction = model.valid_monitor
-    checkpoint = ModelCheckpoint(
-        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
-        mode=direction,every_n_epochs=1
-    )
-    callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-            monitor="val_loss",
-            mode=direction,
-            min_delta=0.0,
-            patience=parameters.get("EarlyStopping_patience",10),
-            strict=True,
-            verbose=False,
-        )
-    callbacks.append(early_stopping)
-
-    def configure_optimizer(self):
-        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
-        scheduler = ReduceLROnPlateau(
-            optimizer=optimizer,
-            mode=direction,
-            factor=parameters.get("ReduceLr_factor",0.1),
-            verbose=True,
-            min_lr=parameters.get("min_lr",1e-6),
-            patience=parameters.get("ReduceLr_patience",3)
-        )
-        return {"optimizer":optimizer, "lr_scheduler":scheduler}
-
-    model.configure_parameters = MethodType(configure_optimizer,model)
-
-    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
-    trainer.fit(model)
-
-    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
-    if os.path.isfile(saved_location):
-        logger.experiment.log_artifact(logger.run_id,saved_location)
-
-
-
-if __name__=="__main__":
-    main()
@@ -0,0 +1,84 @@
+import os
+from types import MethodType
+import hydra
+from hydra.utils import instantiate
+from omegaconf import DictConfig
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.loggers import MLFlowLogger
+
+os.environ["HYDRA_FULL_ERROR"] = "1"
+JOB_ID = os.environ.get("SLURM_JOBID", "0")
+
+
+@hydra.main(config_path="train_config", config_name="config")
+def main(config: DictConfig):
+
+    callbacks = []
+    logger = MLFlowLogger(
+        experiment_name=config.mlflow.experiment_name,
+        run_name=config.mlflow.run_name,
+        tags={"JOB_ID": JOB_ID},
+    )
+
+    parameters = config.hyperparameters
+
+    dataset = instantiate(config.dataset)
+    model = instantiate(
+        config.model,
+        dataset=dataset,
+        lr=parameters.get("lr"),
+        loss=parameters.get("loss"),
+        metric=parameters.get("metric"),
+    )
+
+    direction = model.valid_monitor
+    checkpoint = ModelCheckpoint(
+        dirpath="./model",
+        filename=f"model_{JOB_ID}",
+        monitor="val_loss",
+        verbose=True,
+        mode=direction,
+        every_n_epochs=1,
+    )
+    callbacks.append(checkpoint)
+    early_stopping = EarlyStopping(
+        monitor="val_loss",
+        mode=direction,
+        min_delta=0.0,
+        patience=parameters.get("EarlyStopping_patience", 10),
+        strict=True,
+        verbose=False,
+    )
+    callbacks.append(early_stopping)
+
+    def configure_optimizer(self):
+        optimizer = instantiate(
+            config.optimizer,
+            lr=parameters.get("lr"),
+            parameters=self.parameters(),
+        )
+        scheduler = ReduceLROnPlateau(
+            optimizer=optimizer,
+            mode=direction,
+            factor=parameters.get("ReduceLr_factor", 0.1),
+            verbose=True,
+            min_lr=parameters.get("min_lr", 1e-6),
+            patience=parameters.get("ReduceLr_patience", 3),
+        )
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+    model.configure_parameters = MethodType(configure_optimizer, model)
+
+    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
+    trainer.fit(model)
+
+    saved_location = os.path.join(
+        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
+    )
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id, saved_location)
+
+
+if __name__ == "__main__":
+    main()
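Note: the script above swaps in its optimizer configuration at runtime by binding a closure onto the model instance with types.MethodType (PyTorch Lightning's hook for this is named configure_optimizers). A minimal, self-contained sketch of that binding pattern; the class and return values below are invented for illustration and are not part of this commit:

from types import MethodType


class Model:
    """Stand-in for the instantiated Lightning module."""


model = Model()


def configure_optimizers(self):
    # The closure can capture the Hydra config and hyperparameters from the
    # enclosing scope, just as configure_optimizer does in cli/train.py.
    return {"optimizer": "built-from-config", "lr_scheduler": "plateau"}


# Rebind on this one instance only; the class itself is untouched.
model.configure_optimizers = MethodType(configure_optimizers, model)
print(model.configure_optimizers())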
@@ -0,0 +1,13 @@
+_target_: enhancer.data.dataset.EnhancerDataset
+name : vctk
+root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
+duration : 1.0
+sampling_rate: 16000
+batch_size: 64
+num_workers : 0
+
+files:
+  train_clean : clean_testset_wav
+  test_clean : clean_testset_wav
+  train_noisy : noisy_testset_wav
+  test_noisy : noisy_testset_wav
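Note: the _target_ key in the new dataset config is what lets the script's instantiate(config.dataset) call build an EnhancerDataset: Hydra imports the dotted path and calls it with the remaining keys as keyword arguments. A rough sketch of the mechanism using only a standard-library target (recent Hydra versions accept a plain dict as well as a DictConfig):

from hydra.utils import instantiate

# Equivalent to calling collections.Counter(red=2, blue=1)
counter = instantiate({"_target_": "collections.Counter", "red": 2, "blue": 1})
print(counter)  # Counter({'red': 2, 'blue': 1})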
@@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error
 
 loss_functions = [mean_absolute_error(), mean_squared_error()]
 
 
 def check_loss_shapes_compatibility(loss_fun):
 
     batch_size = 4
-    shape = (1,1000)
-    loss_fun(torch.rand(batch_size,*shape),torch.rand(batch_size,*shape))
+    shape = (1, 1000)
+    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))
 
     with pytest.raises(TypeError):
-        loss_fun(torch.rand(4,*shape),torch.rand(6,*shape))
+        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))
 
 
-@pytest.mark.parametrize("loss",loss_functions)
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_input_shapes(loss):
     check_loss_shapes_compatibility(loss)
 
-@pytest.mark.parametrize("loss",loss_functions)
+
+@pytest.mark.parametrize("loss", loss_functions)
 def test_loss_output_type(loss):
 
     batch_size = 4
-    prediction, target = torch.rand(batch_size,1,1000),torch.rand(batch_size,1,1000)
+    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
+        batch_size, 1, 1000
+    )
     loss_value = loss(prediction, target)
-    assert isinstance(loss_value.item(),float)
-
+    assert isinstance(loss_value.item(), float)
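Note: the pytest.raises(TypeError) assertion above implies that the loss classes in enhancer.loss validate tensor shapes before computing. A hedged sketch of such a guard; this is illustrative, not the repo's actual implementation:

import torch
import torch.nn as nn


class mean_absolute_error(nn.Module):
    # Illustrative only; the real class lives in enhancer.loss.
    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if prediction.size() != target.size():
            raise TypeError(
                f"Inputs must have matching shapes, got "
                f"{tuple(prediction.size())} and {tuple(target.size())}"
            )
        return torch.mean(torch.abs(prediction - target))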
@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-         test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset
 
 
-
-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = Demucs()
     model.eval()
 
-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)
 
-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)
 
 
-@pytest.mark.parametrize("dataset,channels,loss",
-                        [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = Demucs(num_channels=channels,dataset=dataset,loss=loss)
-
-
-
-
-
-
-
+        model = Demucs(num_channels=channels, dataset=dataset, loss=loss)
@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset
 @pytest.fixture
 def vctk_dataset():
     root_dir = "tests/data/vctk"
-    files = Files(train_clean="clean_testset_wav",train_noisy="noisy_testset_wav",
-         test_clean="clean_testset_wav", test_noisy="noisy_testset_wav")
-    dataset = EnhancerDataset(name="vctk",root_dir=root_dir,files=files)
+    files = Files(
+        train_clean="clean_testset_wav",
+        train_noisy="noisy_testset_wav",
+        test_clean="clean_testset_wav",
+        test_noisy="noisy_testset_wav",
+    )
+    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
     return dataset
 
 
-
-@pytest.mark.parametrize("batch_size,samples",[(1,1000)])
-def test_forward(batch_size,samples):
+@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
+def test_forward(batch_size, samples):
     model = WaveUnet()
     model.eval()
 
-    data = torch.rand(batch_size,1,samples,requires_grad=False)
+    data = torch.rand(batch_size, 1, samples, requires_grad=False)
     with torch.no_grad():
         _ = model(data)
 
-    data = torch.rand(batch_size,2,samples,requires_grad=False)
+    data = torch.rand(batch_size, 2, samples, requires_grad=False)
     with torch.no_grad():
         with pytest.raises(TypeError):
             _ = model(data)
 
 
-@pytest.mark.parametrize("dataset,channels,loss",
-                        [(pytest.lazy_fixture("vctk_dataset"),1,["mae","mse"])])
-def test_demucs_init(dataset,channels,loss):
+@pytest.mark.parametrize(
+    "dataset,channels,loss",
+    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
+)
+def test_demucs_init(dataset, channels, loss):
     with torch.no_grad():
-        model = WaveUnet(num_channels=channels,dataset=dataset,loss=loss)
-
-
-
-
-
-
-
+        model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)
@@ -4,22 +4,26 @@ import torch
 from enhancer.inference import Inference
 
 
-@pytest.mark.parametrize("audio",["tests/data/vctk/clean_testset_wav/p257_166.wav",torch.rand(1,2,48000)])
+@pytest.mark.parametrize(
+    "audio",
+    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
+)
 def test_read_input(audio):
 
-    read_audio = Inference.read_input(audio,48000,16000)
-    assert isinstance(read_audio,torch.Tensor)
+    read_audio = Inference.read_input(audio, 48000, 16000)
+    assert isinstance(read_audio, torch.Tensor)
     assert read_audio.shape[0] == 1
 
 
 def test_batchify():
-    rand = torch.rand(1,1000)
-    batched_rand = Inference.batchify(rand, window_size = 100, step_size=100)
+    rand = torch.rand(1, 1000)
+    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
     assert batched_rand.shape[0] == 12
 
 
 def test_aggregate():
-    rand = torch.rand(12,1,100)
-    agg_rand = Inference.aggreagate(data=rand,window_size=100,total_frames=1000,step_size=100)
+    rand = torch.rand(12, 1, 100)
+    agg_rand = Inference.aggreagate(
+        data=rand, window_size=100, total_frames=1000, step_size=100
+    )
     assert agg_rand.shape[-1] == 1000
-
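Note: test_batchify expects 12 windows from 1000 samples with window_size=100 and step_size=100, so Inference.batchify must pad before chunking (only 10 windows fit exactly). The actual padding scheme is not visible in this diff; the sketch below pads one full window on each side, which is simply the amount that reproduces the asserted counts:

import torch
import torch.nn.functional as F


def chunk(wave: torch.Tensor, window: int, step: int, pad: int) -> torch.Tensor:
    # wave: (channels, samples) -> (num_windows, channels, window)
    wave = F.pad(wave, (pad, pad))
    return wave.unfold(-1, window, step).permute(1, 0, 2)


def overlap_add(frames: torch.Tensor, step: int, total: int, pad: int) -> torch.Tensor:
    # frames: (num_windows, channels, window) -> (channels, total);
    # with step == window this reduces to concatenation plus trimming.
    n, c, w = frames.shape
    out = torch.zeros(c, step * (n - 1) + w)
    for i in range(n):
        out[:, i * step : i * step + w] += frames[i]
    return out[:, pad : pad + total]


frames = chunk(torch.rand(1, 1000), window=100, step=100, pad=100)
print(frames.shape)                               # torch.Size([12, 1, 100])
print(overlap_add(frames, 100, 1000, 100).shape)  # torch.Size([1, 1000])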
@@ -7,40 +7,46 @@ from enhancer.utils.io import Audio
 from enhancer.utils.config import Files
 from enhancer.data.fileprocessor import Fileprocessor
 
 
 def test_io_channel():
 
-    input_audio = np.random.rand(2,32000)
-    audio = Audio(mono=True,return_tensor=False)
+    input_audio = np.random.rand(2, 32000)
+    audio = Audio(mono=True, return_tensor=False)
     output_audio = audio(input_audio)
     assert output_audio.shape[0] == 1
 
 
 def test_io_resampling():
 
-    input_audio = np.random.rand(1,32000)
-    resampled_audio = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = np.random.rand(1, 32000)
+    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)
 
-    input_audio = torch.rand(1,32000)
-    resampled_audio_pt = Audio.resample_audio(input_audio,16000,8000)
+    input_audio = torch.rand(1, 32000)
+    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)
 
     assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000
 
 
 def test_fileprocessor_vctk():
 
-    fp = Fileprocessor.from_name("vctk","tests/data/vctk/clean_testset_wav",
-                                "tests/data/vctk/noisy_testset_wav",48000)
+    fp = Fileprocessor.from_name(
+        "vctk",
+        "tests/data/vctk/clean_testset_wav",
+        "tests/data/vctk/noisy_testset_wav",
+        48000,
+    )
     matching_dict = fp.prepare_matching_dict()
-    assert len(matching_dict)==2
+    assert len(matching_dict) == 2
 
-@pytest.mark.parametrize("dataset_name",["vctk","dns-2020"])
+
+@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
 def test_fileprocessor_names(dataset_name):
-     fp = Fileprocessor.from_name(dataset_name,"clean_dir","noisy_dir",16000)
-     assert hasattr(fp.matching_function, '__call__')
+    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
+    assert hasattr(fp.matching_function, "__call__")
 
 
 def test_fileprocessor_invaliname():
     with pytest.raises(ValueError):
-        fp = Fileprocessor.from_name("undefined","clean_dir","noisy_dir",16000).prepare_matching_dict()
-
+        fp = Fileprocessor.from_name(
+            "undefined", "clean_dir", "noisy_dir", 16000
+        ).prepare_matching_dict()
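Note: taken together, the last two tests describe Fileprocessor.from_name as a name-to-strategy dispatcher: valid names ("vctk", "dns-2020") yield an object whose matching_function is callable, while an unknown name only fails once prepare_matching_dict() is called. A hedged sketch of that shape, with the matching functions stubbed out and all bodies invented for illustration:

def match_vctk(clean_dir, noisy_dir):
    ...  # pair clean/noisy files; body invented for illustration


def match_dns2020(clean_dir, noisy_dir):
    ...  # pair clean/noisy files; body invented for illustration


class Fileprocessor:
    # Illustrative only; the real class lives in enhancer.data.fileprocessor.
    MATCHERS = {"vctk": match_vctk, "dns-2020": match_dns2020}

    def __init__(self, clean_dir, noisy_dir, sampling_rate, matching_function):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.sampling_rate = sampling_rate
        self.matching_function = matching_function

    @classmethod
    def from_name(cls, name, clean_dir, noisy_dir, sampling_rate):
        # Unknown names still construct; the error surfaces later, in
        # prepare_matching_dict, matching test_fileprocessor_invaliname.
        matching = cls.MATCHERS.get(name.lower())
        return cls(clean_dir, noisy_dir, sampling_rate, matching)

    def prepare_matching_dict(self):
        if self.matching_function is None:
            raise ValueError("no matching function for this dataset name")
        return self.matching_function(self.clean_dir, self.noisy_dir)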