41 lines
		
	
	
		
			949 B
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			41 lines
		
	
	
		
			949 B
		
	
	
	
		
			Python
		
	
	
	
| import os
 | |
| import random
 | |
| import torch
 | |
| 
 | |
| 
 | |
| 
 | |
def create_unique_rng(epoch: int) -> random.Random:
    """Create a seeded ``random.Random`` for the current worker, rank and epoch.

    The seed is derived from environment variables (``PL_GLOBAL_SEED``,
    ``GLOBAL_RANK``, ``LOCAL_RANK``, ``NODE_RANK``, ``WORLD_SIZE``; each is
    treated as 0 when unset, except ``WORLD_SIZE`` which is treated as 1) and
    from ``torch.utils.data.get_worker_info()``.

    Args:
        epoch: Epoch index; each epoch shifts the seed by
            ``num_workers * world_size``.

    Returns:
        A ``random.Random`` instance seeded with the derived integer.

    Raises:
        ValueError: If one of the environment variables is set to a value
            that ``int()`` cannot parse.
    """
    global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
    global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    node_rank = int(os.environ.get("NODE_RANK", "0"))
    # A world size of 0 (the previous default when WORLD_SIZE was unset) would
    # zero out the epoch term below and hand every epoch the same seed, so
    # treat a missing or non-positive value as a single process.
    world_size = max(int(os.environ.get("WORLD_SIZE", "1")), 1)

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        num_workers = worker_info.num_workers
        worker_id = worker_info.worker_id
    else:
        # get_worker_info() returned None: fall back to a single worker 0.
        num_workers = 1
        worker_id = 0

    # NOTE(review): the node term multiplies node_rank by global_rank; on
    # multi-node runs the per-process offset can then reach past
    # num_workers * world_size and overlap the next epoch's range. Kept as-is
    # to preserve existing seeds -- confirm the intended formula.
    seed = (
        global_seed
        + worker_id
        + local_rank * num_workers
        + node_rank * num_workers * global_rank
        + epoch * num_workers * world_size
    )

    return random.Random(seed)
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 |