# File viewer metadata: 37 lines, 920 B, Python
import os
 | 
						|
import random
 | 
						|
 | 
						|
import torch
 | 
						|
 | 
						|
 | 
						|
def create_unique_rng(epoch: int) -> random.Random:
    """Create a random number generator unique to each (rank, worker, epoch).

    The seed is the global seed plus a mixed-radix index over
    ``(epoch, rank, worker_id)``, so two distinct combinations never share a
    seed as long as ``worker_id < num_workers`` and ``rank < world_size``.

    Environment variables read (all optional, parsed as integers):
        PL_GLOBAL_SEED: base seed; defaults to 0.
        GLOBAL_RANK: index of this process across all nodes; used as the rank
            whenever it is set.
        LOCAL_RANK: index of this process on its node; used as the rank only
            when GLOBAL_RANK is not set.
        WORLD_SIZE: total number of processes; missing or values below 1 are
            treated as 1.

    Args:
        epoch: Epoch number folded into the seed.

    Returns:
        A ``random.Random`` instance seeded for the current combination.

    Raises:
        ValueError: If one of the environment variables is not an integer.
    """
    env = os.environ
    global_seed = int(env.get("PL_GLOBAL_SEED", "0"))

    # A world of fewer than one process makes no sense; clamping keeps the
    # epoch stride non-zero so different epochs always get different seeds,
    # even when WORLD_SIZE is unset or "0".
    world_size = max(int(env.get("WORLD_SIZE", "1")), 1)

    # The rank must be a dense index in [0, world_size) for the seed layout to
    # be collision-free. GLOBAL_RANK is exactly that; LOCAL_RANK is only
    # equivalent on a single node.
    # NOTE: without GLOBAL_RANK, processes on different nodes that share a
    # LOCAL_RANK still receive the same seed.
    if "GLOBAL_RANK" in env:
        rank = int(env["GLOBAL_RANK"])
    else:
        rank = int(env.get("LOCAL_RANK", "0"))

    # get_worker_info() returns None when called outside a DataLoader worker.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        num_workers, worker_id = 1, 0
    else:
        num_workers, worker_id = worker_info.num_workers, worker_info.id

    # Layout: worker_id is the fastest-varying digit, then rank, then epoch.
    seed = global_seed + worker_id + num_workers * (rank + world_size * epoch)
    return random.Random(seed)
 |