fix randomization
This commit is contained in:
parent
a75f3c32a3
commit
cd9ffc1a68
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import random
|
||||
from itertools import chain, cycle
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -120,7 +121,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
valid_min_now = 0.0
|
||||
valid_indices = []
|
||||
random_indices = list(range(0, len(data)))
|
||||
rng = create_unique_rng(random_state, 0)
|
||||
rng = create_unique_rng(random_state)
|
||||
rng.shuffle(random_indices)
|
||||
i = 0
|
||||
while valid_min_now <= valid_minutes:
|
||||
|
|
@ -285,14 +286,15 @@ class EnhancerDataset(TaskDataset):
|
|||
|
||||
super().setup(stage=stage)
|
||||
|
||||
def random_sample(self, index):
|
||||
rng = create_unique_rng(self.model.current_epoch, index)
|
||||
return rng.sample(self.train_data, len(self.train_data))
|
||||
def random_sample(self, train_data):
|
||||
return random.sample(train_data, len(train_data))
|
||||
|
||||
def train__iter__(self):
|
||||
rng = create_unique_rng(self.model.current_epoch)
|
||||
train_data = rng.sample(self.train_data, len(self.train_data))
|
||||
return zip(
|
||||
*[
|
||||
self.get_stream(self.random_sample(i))
|
||||
self.get_stream(self.random_sample(train_data))
|
||||
for i in range(self.batch_size)
|
||||
]
|
||||
)
|
||||
|
|
@ -353,8 +355,10 @@ class EnhancerDataset(TaskDataset):
|
|||
}
|
||||
|
||||
def train__len__(self):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
num_workers = worker_info.num_workers if worker_info else 1
|
||||
if self.num_workers > 1:
|
||||
num_workers = 2
|
||||
else:
|
||||
num_workers = 1
|
||||
return sum([len(item) for item in self.train_data]) // (
|
||||
self.batch_size * num_workers
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import random
|
|||
import torch
|
||||
|
||||
|
||||
def create_unique_rng(epoch: int, index: int):
|
||||
def create_unique_rng(epoch: int):
|
||||
"""create unique random number generator for each (worker_id,epoch) combination"""
|
||||
|
||||
rng = random.Random()
|
||||
|
|
@ -29,7 +29,6 @@ def create_unique_rng(epoch: int, index: int):
|
|||
+ local_rank * num_workers
|
||||
+ node_rank * num_workers * global_rank
|
||||
+ epoch * num_workers * world_size
|
||||
+ index
|
||||
)
|
||||
|
||||
rng.seed(seed)
|
||||
|
|
|
|||
Loading…
Reference in New Issue