diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 9b512ab..cc89083 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -351,7 +351,7 @@ class EnhancerDataset(TaskDataset): def train__len__(self): - return sum([len(item) for item in self.train_data]) + return sum([len(item) for item in self.train_data]) // self.batch_size def val__len__(self): return len(self._validation)