reformat utils
This commit is contained in:
		
							parent
							
								
									aca4521ef2
								
							
						
					
					
						commit
						6c4bced360
					
				|  | @ -1,10 +1,9 @@ | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class Files: | class Files: | ||||||
|     train_clean : str |     train_clean: str | ||||||
|     train_noisy : str |     train_noisy: str | ||||||
|     test_clean : str |     test_clean: str | ||||||
|     test_noisy : str |     test_noisy: str | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|  |  | ||||||
|  | @ -3,17 +3,16 @@ import random | ||||||
| import torch | import torch | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | def create_unique_rng(epoch: int): | ||||||
| def create_unique_rng(epoch:int): |  | ||||||
|     """create unique random number generator for each (worker_id,epoch) combination""" |     """create unique random number generator for each (worker_id,epoch) combination""" | ||||||
| 
 | 
 | ||||||
|     rng = random.Random() |     rng = random.Random() | ||||||
| 
 | 
 | ||||||
|     global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) |     global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0")) | ||||||
|     global_rank = int(os.environ.get('GLOBAL_RANK',"0")) |     global_rank = int(os.environ.get("GLOBAL_RANK", "0")) | ||||||
|     local_rank = int(os.environ.get('LOCAL_RANK',"0")) |     local_rank = int(os.environ.get("LOCAL_RANK", "0")) | ||||||
|     node_rank = int(os.environ.get('NODE_RANK',"0")) |     node_rank = int(os.environ.get("NODE_RANK", "0")) | ||||||
|     world_size = int(os.environ.get('WORLD_SIZE',"0")) |     world_size = int(os.environ.get("WORLD_SIZE", "0")) | ||||||
| 
 | 
 | ||||||
|     worker_info = torch.utils.data.get_worker_info() |     worker_info = torch.utils.data.get_worker_info() | ||||||
|     if worker_info is not None: |     if worker_info is not None: | ||||||
|  | @ -24,17 +23,13 @@ def create_unique_rng(epoch:int): | ||||||
|         worker_id = 0 |         worker_id = 0 | ||||||
| 
 | 
 | ||||||
|     seed = ( |     seed = ( | ||||||
|             global_seed |         global_seed | ||||||
|             + worker_id |         + worker_id | ||||||
|             + local_rank * num_workers |         + local_rank * num_workers | ||||||
|             + node_rank * num_workers * global_rank |         + node_rank * num_workers * global_rank | ||||||
|             + epoch * num_workers * world_size |         + epoch * num_workers * world_size | ||||||
|         ) |     ) | ||||||
| 
 | 
 | ||||||
|     rng.seed(seed) |     rng.seed(seed) | ||||||
| 
 | 
 | ||||||
|     return rng |     return rng | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|  |  | ||||||
|  | @ -1,19 +1,24 @@ | ||||||
| 
 |  | ||||||
| import os | import os | ||||||
| from typing import Optional | from typing import Optional | ||||||
| from enhancer.utils.config import Files | from enhancer.utils.config import Files | ||||||
| 
 | 
 | ||||||
| def check_files(root_dir:str, files:Files): |  | ||||||
| 
 | 
 | ||||||
|     path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')] | def check_files(root_dir: str, files: Files): | ||||||
|  | 
 | ||||||
|  |     path_variables = [ | ||||||
|  |         member_var | ||||||
|  |         for member_var in dir(files) | ||||||
|  |         if not member_var.startswith("__") | ||||||
|  |     ] | ||||||
|     for variable in path_variables: |     for variable in path_variables: | ||||||
|         path = getattr(files,variable) |         path = getattr(files, variable) | ||||||
|         if not os.path.isdir(os.path.join(root_dir,path)): |         if not os.path.isdir(os.path.join(root_dir, path)): | ||||||
|             raise ValueError(f"Invalid {path}, is not a directory") |             raise ValueError(f"Invalid {path}, is not a directory") | ||||||
| 
 | 
 | ||||||
|     return files,root_dir |     return files, root_dir | ||||||
| 
 | 
 | ||||||
| def merge_dict(default_dict:dict, custom:Optional[dict]=None): | 
 | ||||||
|  | def merge_dict(default_dict: dict, custom: Optional[dict] = None): | ||||||
|     params = dict(default_dict) |     params = dict(default_dict) | ||||||
|     if custom: |     if custom: | ||||||
|         params.update(custom) |         params.update(custom) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786