fix randomization
This commit is contained in:
parent
a75f3c32a3
commit
cd9ffc1a68
|
|
@ -1,6 +1,7 @@
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
from itertools import chain, cycle
|
from itertools import chain, cycle
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
@ -120,7 +121,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
valid_min_now = 0.0
|
valid_min_now = 0.0
|
||||||
valid_indices = []
|
valid_indices = []
|
||||||
random_indices = list(range(0, len(data)))
|
random_indices = list(range(0, len(data)))
|
||||||
rng = create_unique_rng(random_state, 0)
|
rng = create_unique_rng(random_state)
|
||||||
rng.shuffle(random_indices)
|
rng.shuffle(random_indices)
|
||||||
i = 0
|
i = 0
|
||||||
while valid_min_now <= valid_minutes:
|
while valid_min_now <= valid_minutes:
|
||||||
|
|
@ -285,14 +286,15 @@ class EnhancerDataset(TaskDataset):
|
||||||
|
|
||||||
super().setup(stage=stage)
|
super().setup(stage=stage)
|
||||||
|
|
||||||
def random_sample(self, index):
|
def random_sample(self, train_data):
|
||||||
rng = create_unique_rng(self.model.current_epoch, index)
|
return random.sample(train_data, len(train_data))
|
||||||
return rng.sample(self.train_data, len(self.train_data))
|
|
||||||
|
|
||||||
def train__iter__(self):
|
def train__iter__(self):
|
||||||
|
rng = create_unique_rng(self.model.current_epoch)
|
||||||
|
train_data = rng.sample(self.train_data, len(self.train_data))
|
||||||
return zip(
|
return zip(
|
||||||
*[
|
*[
|
||||||
self.get_stream(self.random_sample(i))
|
self.get_stream(self.random_sample(train_data))
|
||||||
for i in range(self.batch_size)
|
for i in range(self.batch_size)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
@ -353,8 +355,10 @@ class EnhancerDataset(TaskDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
def train__len__(self):
|
def train__len__(self):
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
if self.num_workers > 1:
|
||||||
num_workers = worker_info.num_workers if worker_info else 1
|
num_workers = 2
|
||||||
|
else:
|
||||||
|
num_workers = 1
|
||||||
return sum([len(item) for item in self.train_data]) // (
|
return sum([len(item) for item in self.train_data]) // (
|
||||||
self.batch_size * num_workers
|
self.batch_size * num_workers
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import random
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def create_unique_rng(epoch: int, index: int):
|
def create_unique_rng(epoch: int):
|
||||||
"""create unique random number generator for each (worker_id,epoch) combination"""
|
"""create unique random number generator for each (worker_id,epoch) combination"""
|
||||||
|
|
||||||
rng = random.Random()
|
rng = random.Random()
|
||||||
|
|
@ -29,7 +29,6 @@ def create_unique_rng(epoch: int, index: int):
|
||||||
+ local_rank * num_workers
|
+ local_rank * num_workers
|
||||||
+ node_rank * num_workers * global_rank
|
+ node_rank * num_workers * global_rank
|
||||||
+ epoch * num_workers * world_size
|
+ epoch * num_workers * world_size
|
||||||
+ index
|
|
||||||
)
|
)
|
||||||
|
|
||||||
rng.seed(seed)
|
rng.seed(seed)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue