debug
This commit is contained in:
parent
20c12556ff
commit
a7fb27bb0f
|
|
@ -353,8 +353,11 @@ class EnhancerDataset(TaskDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
def train__len__(self):
|
def train__len__(self):
|
||||||
|
worker_info = torch.utils.data.get_worker_info()
|
||||||
return sum([len(item) for item in self.train_data]) // self.batch_size
|
num_workers = worker_info.num_workers if worker_info else 1
|
||||||
|
return sum([len(item) for item in self.train_data]) // (
|
||||||
|
self.batch_size * num_workers
|
||||||
|
)
|
||||||
|
|
||||||
def val__len__(self):
|
def val__len__(self):
|
||||||
return len(self._validation)
|
return len(self._validation)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue