fix randomization

This commit is contained in:
shahules786 2022-10-21 23:22:56 +05:30
parent a75f3c32a3
commit cd9ffc1a68
2 changed files with 12 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
import math import math
import multiprocessing import multiprocessing
import os import os
import random
from itertools import chain, cycle from itertools import chain, cycle
from typing import Optional from typing import Optional
@@ -120,7 +121,7 @@ class TaskDataset(pl.LightningDataModule):
valid_min_now = 0.0 valid_min_now = 0.0
valid_indices = [] valid_indices = []
random_indices = list(range(0, len(data))) random_indices = list(range(0, len(data)))
rng = create_unique_rng(random_state, 0) rng = create_unique_rng(random_state)
rng.shuffle(random_indices) rng.shuffle(random_indices)
i = 0 i = 0
while valid_min_now <= valid_minutes: while valid_min_now <= valid_minutes:
@@ -285,14 +286,15 @@ class EnhancerDataset(TaskDataset):
super().setup(stage=stage) super().setup(stage=stage)
def random_sample(self, index): def random_sample(self, train_data):
rng = create_unique_rng(self.model.current_epoch, index) return random.sample(train_data, len(train_data))
return rng.sample(self.train_data, len(self.train_data))
def train__iter__(self): def train__iter__(self):
rng = create_unique_rng(self.model.current_epoch)
train_data = rng.sample(self.train_data, len(self.train_data))
return zip( return zip(
*[ *[
self.get_stream(self.random_sample(i)) self.get_stream(self.random_sample(train_data))
for i in range(self.batch_size) for i in range(self.batch_size)
] ]
) )
@@ -353,8 +355,10 @@ class EnhancerDataset(TaskDataset):
} }
def train__len__(self): def train__len__(self):
worker_info = torch.utils.data.get_worker_info() if self.num_workers > 1:
num_workers = worker_info.num_workers if worker_info else 1 num_workers = 2
else:
num_workers = 1
return sum([len(item) for item in self.train_data]) // ( return sum([len(item) for item in self.train_data]) // (
self.batch_size * num_workers self.batch_size * num_workers
) )

View File

@@ -4,7 +4,7 @@ import random
import torch import torch
def create_unique_rng(epoch: int, index: int): def create_unique_rng(epoch: int):
"""create unique random number generator for each (worker_id,epoch) combination""" """create unique random number generator for each (worker_id,epoch) combination"""
rng = random.Random() rng = random.Random()
@@ -29,7 +29,6 @@ def create_unique_rng(epoch: int, index: int):
+ local_rank * num_workers + local_rank * num_workers
+ node_rank * num_workers * global_rank + node_rank * num_workers * global_rank
+ epoch * num_workers * world_size + epoch * num_workers * world_size
+ index
) )
rng.seed(seed) rng.seed(seed)