diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 02a1d3b..28f19a6 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -132,11 +132,14 @@ class TaskDataset(pl.LightningDataModule):
         for item in data:
             clean, noisy, total_dur = item.values()
             if total_dur < self.duration:
-                continue
-            num_segments = round(total_dur / self.duration)
-            for index in range(num_segments):
-                start_time = index * self.duration
-                metadata.append(({"clean": clean, "noisy": noisy}, start_time))
+                metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
+            else:
+                num_segments = round(total_dur / self.duration)
+                for index in range(num_segments):
+                    start_time = index * self.duration
+                    metadata.append(
+                        ({"clean": clean, "noisy": noisy}, start_time)
+                    )
         return metadata
 
     def train_dataloader(self):
@@ -195,6 +198,7 @@ class EnhancerDataset(TaskDataset):
         files: Files,
         valid_minutes=5.0,
         duration=1.0,
+        stride=0.5,
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
@@ -217,6 +221,8 @@ class EnhancerDataset(TaskDataset):
         self.files = files
         self.duration = max(1.0, duration)
         self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
+        # A falsy stride (None or 0) falls back to the clamped duration.
+        self.stride = stride or self.duration
 
     def setup(self, stage: Optional[str] = None):
 
@@ -234,9 +240,25 @@ class EnhancerDataset(TaskDataset):
                 weights=[file["duration"] for file in self.train_data],
             )
             file_duration = file_dict["duration"]
-            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
-            data = self.prepare_segment(file_dict, start_time)
-            yield data
+            # Emit every strided window of the chosen file, in order.
+            num_segments = self.get_num_segments(
+                file_duration, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start_time = index * self.stride
+                yield self.prepare_segment(file_dict, start_time)
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+        """Count the training windows taken from one file.
+
+        A file shorter than ``duration`` still contributes one window.
+        Otherwise windows start at multiples of ``stride`` and the count is
+        rounded up, so the last window may extend past the end of the file.
+        """
+        if file_duration < duration:
+            return 1
+        return math.ceil((file_duration - duration) / stride) + 1
 
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
@@ -273,8 +295,10 @@ class EnhancerDataset(TaskDataset):
         return {"clean": clean_segment, "noisy": noisy_segment}
 
     def train__len__(self):
-        return math.ceil(
-            sum([file["duration"] for file in self.train_data]) / self.duration
+        """Return the total number of segments over all training files."""
+        return sum(
+            self.get_num_segments(file["duration"], self.duration, self.stride)
+            for file in self.train_data
         )
 
     def val__len__(self):