From cd9ffc1a684fe255b3c8a221145e68de9cb410b7 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Fri, 21 Oct 2022 23:22:56 +0530 Subject: [PATCH] fix randomization --- enhancer/data/dataset.py | 18 +++++++++++------- enhancer/utils/random.py | 3 +-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index a8b3896..d4686b5 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,6 +1,7 @@ import math import multiprocessing import os +import random from itertools import chain, cycle from typing import Optional @@ -120,7 +121,7 @@ class TaskDataset(pl.LightningDataModule): valid_min_now = 0.0 valid_indices = [] random_indices = list(range(0, len(data))) - rng = create_unique_rng(random_state, 0) + rng = create_unique_rng(random_state) rng.shuffle(random_indices) i = 0 while valid_min_now <= valid_minutes: @@ -285,14 +286,15 @@ class EnhancerDataset(TaskDataset): super().setup(stage=stage) - def random_sample(self, index): - rng = create_unique_rng(self.model.current_epoch, index) - return rng.sample(self.train_data, len(self.train_data)) + def random_sample(self, train_data): + return random.sample(train_data, len(train_data)) def train__iter__(self): + rng = create_unique_rng(self.model.current_epoch) + train_data = rng.sample(self.train_data, len(self.train_data)) return zip( *[ - self.get_stream(self.random_sample(i)) + self.get_stream(self.random_sample(train_data)) for i in range(self.batch_size) ] ) @@ -353,8 +355,10 @@ class EnhancerDataset(TaskDataset): } def train__len__(self): - worker_info = torch.utils.data.get_worker_info() - num_workers = worker_info.num_workers if worker_info else 1 + if self.num_workers > 1: + num_workers = 2 + else: + num_workers = 1 return sum([len(item) for item in self.train_data]) // ( self.batch_size * num_workers ) diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 2feb581..dd9395a 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -4,7 +4,7 @@ import random import torch -def create_unique_rng(epoch: int, index: int): +def create_unique_rng(epoch: int): """create unique random number generator for each (worker_id,epoch) combination""" rng = random.Random() @@ -29,7 +29,6 @@ def create_unique_rng(epoch: int, index: int): + local_rank * num_workers + node_rank * num_workers * global_rank + epoch * num_workers * world_size - + index ) rng.seed(seed)