utils
This commit is contained in:
parent
86ccdbb5cb
commit
deccda5389
|
|
@ -0,0 +1,40 @@
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_unique_rnge(epoch:int):
|
||||||
|
"""create unique random number generator for each (worker_id,epoch) combination"""
|
||||||
|
|
||||||
|
rng = random.Random()
|
||||||
|
|
||||||
|
global_seed = int(os.environ.get("PL_GLOBAL_SEED","0"))
|
||||||
|
global_rank = int(os.environ.get('GLOBAL_RANK',"0"))
|
||||||
|
local_rank = int(os.environ.get('LOCAL_RANK',"0"))
|
||||||
|
node_rank = int(os.environ.get('NODE_RANK',"0"))
|
||||||
|
world_size = int(os.environ.get('WORLD_SIZE',"0"))
|
||||||
|
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
|
if worker_info is not None:
|
||||||
|
num_workers = worker_info.num_workers
|
||||||
|
worker_id = worker_info.worker_id
|
||||||
|
else:
|
||||||
|
num_workers = 1
|
||||||
|
worker_id = 0
|
||||||
|
|
||||||
|
seed = (
|
||||||
|
global_seed
|
||||||
|
+ worker_id
|
||||||
|
+ local_rank * num_workers
|
||||||
|
+ node_rank * num_workers * global_rank
|
||||||
|
+ epoch * num_workers * world_size
|
||||||
|
)
|
||||||
|
|
||||||
|
rng.seed(seed)
|
||||||
|
|
||||||
|
return rng
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue