utils
This commit is contained in:
		
							parent
							
								
									86ccdbb5cb
								
							
						
					
					
						commit
						deccda5389
					
				|  | @ -0,0 +1,40 @@ | |||
| import os | ||||
| import random | ||||
| import torch | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def create_unique_rnge(epoch:int): | ||||
|     """create unique random number generator for each (worker_id,epoch) combination""" | ||||
| 
 | ||||
|     rng = random.Random() | ||||
| 
 | ||||
|     global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) | ||||
|     global_rank = int(os.environ.get('GLOBAL_RANK',"0")) | ||||
|     local_rank = int(os.environ.get('LOCAL_RANK',"0")) | ||||
|     node_rank = int(os.environ.get('NODE_RANK',"0")) | ||||
|     world_size = int(os.environ.get('WORLD_SIZE',"0")) | ||||
| 
 | ||||
|     worker_info = torch.utils.data.get_worker_info() | ||||
|     if worker_info is not None: | ||||
|         num_workers = worker_info.num_workers | ||||
|         worker_id = worker_info.worker_id | ||||
|     else: | ||||
|         num_workers = 1 | ||||
|         worker_id = 0 | ||||
| 
 | ||||
|     seed = ( | ||||
|             global_seed | ||||
|             + worker_id | ||||
|             + local_rank * num_workers | ||||
|             + node_rank * num_workers * global_rank | ||||
|             + epoch * num_workers * world_size | ||||
|         ) | ||||
| 
 | ||||
|     rng.seed(seed) | ||||
| 
 | ||||
|     return rng | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786