diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 06a6f67..218cd7c 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -189,8 +189,8 @@ class TaskDataset(pl.LightningDataModule): worker_info = torch.utils.data.get_worker_info() dataset = worker_info.dataset worker_id = worker_info.id - split_size = len(dataset.data) // worker_info.num_workers - dataset.data = dataset.data[ + split_size = len(dataset.dataset.train_data) // worker_info.num_workers + dataset.data = dataset.dataset.train_data[ worker_id * split_size : (worker_id + 1) * split_size ]