diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index a8b3896..d4686b5 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,6 +1,7 @@ import math import multiprocessing import os +import random from itertools import chain, cycle from typing import Optional @@ -120,7 +121,7 @@ class TaskDataset(pl.LightningDataModule): valid_min_now = 0.0 valid_indices = [] random_indices = list(range(0, len(data))) - rng = create_unique_rng(random_state, 0) + rng = create_unique_rng(random_state) rng.shuffle(random_indices) i = 0 while valid_min_now <= valid_minutes: @@ -285,14 +286,15 @@ class EnhancerDataset(TaskDataset): super().setup(stage=stage) - def random_sample(self, index): - rng = create_unique_rng(self.model.current_epoch, index) - return rng.sample(self.train_data, len(self.train_data)) + def random_sample(self, train_data): + return random.sample(train_data, len(train_data)) def train__iter__(self): + rng = create_unique_rng(self.model.current_epoch) + train_data = rng.sample(self.train_data, len(self.train_data)) return zip( *[ - self.get_stream(self.random_sample(i)) + self.get_stream(self.random_sample(train_data)) for i in range(self.batch_size) ] ) @@ -353,8 +355,10 @@ class EnhancerDataset(TaskDataset): } def train__len__(self): - worker_info = torch.utils.data.get_worker_info() - num_workers = worker_info.num_workers if worker_info else 1 + if self.num_workers > 1: + num_workers = 2 + else: + num_workers = 1 return sum([len(item) for item in self.train_data]) // ( self.batch_size * num_workers ) diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 2feb581..dd9395a 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -4,7 +4,7 @@ import random import torch -def create_unique_rng(epoch: int, index: int): +def create_unique_rng(epoch: int): """create unique random number generator for each (worker_id,epoch) combination""" rng = random.Random() @@ -29,7 +29,6 @@ def create_unique_rng(epoch: int, index: int): + local_rank * num_workers + node_rank * num_workers * global_rank + epoch * num_workers * world_size - + index ) rng.seed(seed)