fix worker init fn
This commit is contained in:
parent
ba10719520
commit
178a4523ef
|
|
@ -189,8 +189,8 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
dataset = worker_info.dataset
|
dataset = worker_info.dataset
|
||||||
worker_id = worker_info.id
|
worker_id = worker_info.id
|
||||||
split_size = len(dataset.data) // worker_info.num_workers
|
split_size = len(dataset.dataset.train_data) // worker_info.num_workers
|
||||||
dataset.data = dataset.data[
|
dataset.data = dataset.dataset.train_data[
|
||||||
worker_id * split_size : (worker_id + 1) * split_size
|
worker_id * split_size : (worker_id + 1) * split_size
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue