debug
This commit is contained in:
parent
c4a27686da
commit
7fa54fc414
|
|
@ -355,14 +355,12 @@ class EnhancerDataset(TaskDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
def train__len__(self):
|
def train__len__(self):
|
||||||
if self.num_workers > 1:
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
num_workers = 2
|
if worker_info is None:
|
||||||
|
train_data = self.train_data
|
||||||
else:
|
else:
|
||||||
num_workers = 1
|
train_data = worker_info.dataset.data
|
||||||
print("num_workers", num_workers)
|
return sum([len(item) for item in train_data]) // (self.batch_size)
|
||||||
return sum([len(item) for item in self.train_data]) // (
|
|
||||||
self.batch_size * self.num_workers
|
|
||||||
)
|
|
||||||
|
|
||||||
def val__len__(self):
|
def val__len__(self):
|
||||||
return len(self._validation)
|
return len(self._validation)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue