Fix randomization: seed the shuffle RNG per epoch and drop the per-index seed component

This commit is contained in:
shahules786 2022-10-21 23:22:56 +05:30
parent a75f3c32a3
commit cd9ffc1a68
2 changed files with 12 additions and 9 deletions

View File

@ -1,6 +1,7 @@
import math
import multiprocessing
import os
import random
from itertools import chain, cycle
from typing import Optional
@ -120,7 +121,7 @@ class TaskDataset(pl.LightningDataModule):
valid_min_now = 0.0
valid_indices = []
random_indices = list(range(0, len(data)))
rng = create_unique_rng(random_state, 0)
rng = create_unique_rng(random_state)
rng.shuffle(random_indices)
i = 0
while valid_min_now <= valid_minutes:
@ -285,14 +286,15 @@ class EnhancerDataset(TaskDataset):
super().setup(stage=stage)
def random_sample(self, index):
rng = create_unique_rng(self.model.current_epoch, index)
return rng.sample(self.train_data, len(self.train_data))
def random_sample(self, train_data):
return random.sample(train_data, len(train_data))
def train__iter__(self):
rng = create_unique_rng(self.model.current_epoch)
train_data = rng.sample(self.train_data, len(self.train_data))
return zip(
*[
self.get_stream(self.random_sample(i))
self.get_stream(self.random_sample(train_data))
for i in range(self.batch_size)
]
)
@ -353,8 +355,10 @@ class EnhancerDataset(TaskDataset):
}
def train__len__(self):
worker_info = torch.utils.data.get_worker_info()
num_workers = worker_info.num_workers if worker_info else 1
if self.num_workers > 1:
num_workers = 2
else:
num_workers = 1
return sum([len(item) for item in self.train_data]) // (
self.batch_size * num_workers
)

View File

@ -4,7 +4,7 @@ import random
import torch
def create_unique_rng(epoch: int, index: int):
def create_unique_rng(epoch: int):
"""create unique random number generator for each (worker_id,epoch) combination"""
rng = random.Random()
@ -29,7 +29,6 @@ def create_unique_rng(epoch: int, index: int):
+ local_rank * num_workers
+ node_rank * num_workers * global_rank
+ epoch * num_workers * world_size
+ index
)
rng.seed(seed)