utils
This commit is contained in:
		
							parent
							
								
									86ccdbb5cb
								
							
						
					
					
						commit
						deccda5389
					
				|  | @ -0,0 +1,40 @@ | ||||||
|  | import os | ||||||
|  | import random | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def create_unique_rnge(epoch:int): | ||||||
|  |     """create unique random number generator for each (worker_id,epoch) combination""" | ||||||
|  | 
 | ||||||
|  |     rng = random.Random() | ||||||
|  | 
 | ||||||
|  |     global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) | ||||||
|  |     global_rank = int(os.environ.get('GLOBAL_RANK',"0")) | ||||||
|  |     local_rank = int(os.environ.get('LOCAL_RANK',"0")) | ||||||
|  |     node_rank = int(os.environ.get('NODE_RANK',"0")) | ||||||
|  |     world_size = int(os.environ.get('WORLD_SIZE',"0")) | ||||||
|  | 
 | ||||||
|  |     worker_info = torch.utils.data.get_worker_info() | ||||||
|  |     if worker_info is not None: | ||||||
|  |         num_workers = worker_info.num_workers | ||||||
|  |         worker_id = worker_info.worker_id | ||||||
|  |     else: | ||||||
|  |         num_workers = 1 | ||||||
|  |         worker_id = 0 | ||||||
|  | 
 | ||||||
|  |     seed = ( | ||||||
|  |             global_seed | ||||||
|  |             + worker_id | ||||||
|  |             + local_rank * num_workers | ||||||
|  |             + node_rank * num_workers * global_rank | ||||||
|  |             + epoch * num_workers * world_size | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     rng.seed(seed) | ||||||
|  | 
 | ||||||
|  |     return rng | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
		Loading…
	
		Reference in New Issue
	
	 shahules786
						shahules786