debug iterative dataset

shahules786 2022-10-20 09:49:27 +05:30
parent e4f13946e8
commit 2ad49faa67
1 changed file with 74 additions and 28 deletions


@@ -1,9 +1,11 @@
 import math
 import multiprocessing
 import os
+from itertools import chain, cycle
 from typing import Optional

 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, IterableDataset
@@ -55,6 +57,7 @@ class TaskDataset(pl.LightningDataModule):
         files: Files,
         valid_minutes: float = 0.20,
         duration: float = 1.0,
+        stride=None,
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,
@@ -65,6 +68,7 @@ class TaskDataset(pl.LightningDataModule):
         self.name = name
         self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
+        self.stride = stride or duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
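
The new stride defaults to the window duration, so windows tile edge to edge; a smaller stride makes consecutive windows overlap. A minimal sketch of the `stride or duration` fallback (note that 0 is falsy and also falls back):

```python
def resolve_stride(stride, duration):
    # Same rule as `self.stride = stride or duration` above.
    return stride or duration

assert resolve_stride(None, 1.0) == 1.0  # default: non-overlapping windows
assert resolve_stride(0.5, 1.0) == 0.5   # 0.5 s hop -> 50% overlap
assert resolve_stride(0, 1.0) == 1.0     # 0 is falsy, silently becomes 1.0
```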
@@ -90,10 +94,11 @@ class TaskDataset(pl.LightningDataModule):
             self.name, train_clean, train_noisy, self.matching_function
         )
         train_data = fp.prepare_matching_dict()
-        self.train_data, self.val_data = self.train_valid_split(
+        train_data, self.val_data = self.train_valid_split(
             train_data, valid_minutes=self.valid_minutes, random_state=42
         )
+        self.train_data = self.prepare_traindata(train_data)

         self._validation = self.prepare_mapstype(self.val_data)
         test_clean = os.path.join(self.root_dir, self.files.test_clean)
@@ -112,7 +117,7 @@ class TaskDataset(pl.LightningDataModule):
         valid_min_now = 0.0
         valid_indices = []
         random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state)
+        rng = create_unique_rng(random_state, 0)
         rng.shuffle(random_indices)
         i = 0
         while valid_min_now <= valid_minutes:
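
`create_unique_rng` is a repo utility not shown in this diff; the change passes a second component so the split's RNG stream is decoupled from the per-epoch, per-stream RNGs used by `random_sample` further down. A hypothetical sketch of such a helper (the mixing scheme is an assumption, not the repo's actual code):

```python
import random

def create_unique_rng(*components: int) -> random.Random:
    # Hypothetical: fold integer components (e.g. epoch, stream index)
    # into one deterministic seed so each combination gets its own stream.
    seed = 0
    for c in components:
        seed = seed * 1_000_003 + c
    return random.Random(seed)

split_rng = create_unique_rng(42, 0)  # train/valid split, as in the diff
epoch_rng = create_unique_rng(7, 3)   # e.g. epoch 7, batch-stream 3
```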
@@ -126,6 +131,33 @@ class TaskDataset(pl.LightningDataModule):
         valid_data = [item for i, item in enumerate(data) if i in valid_indices]
         return train_data, valid_data

+    def prepare_traindata(self, data):
+        train_data = []
+        for item in data:
+            samples_metadata = []
+            clean, noisy, total_dur = item.values()
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start = index * self.stride
+                samples_metadata.append(
+                    ({"clean": clean, "noisy": noisy}, start)
+                )
+            train_data.append(samples_metadata)
+        print(train_data[:10])
+        return train_data
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+        if file_duration < duration:
+            num_segments = 1
+        else:
+            num_segments = math.ceil((file_duration - duration) / stride) + 1
+        return num_segments
+
     def prepare_mapstype(self, data):
         metadata = []
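
`prepare_traindata` expands every file into `(file_dict, start_time)` tuples up front, one per window. The window count is `ceil((file_duration - duration) / stride) + 1`: one window at offset 0, plus one per additional stride, the last possibly extending past the file end. A standalone check of the arithmetic, mirroring `get_num_segments` above:

```python
import math

def get_num_segments(file_duration, duration, stride):
    if file_duration < duration:
        return 1  # short files still yield one window
    return math.ceil((file_duration - duration) / stride) + 1

# A 4.5 s file cut into 1.0 s windows with a 1.0 s stride:
# starts at 0.0, 1.0, 2.0, 3.0, plus a final partial window at 4.0.
assert get_num_segments(4.5, 1.0, 1.0) == 5

# Halving the stride roughly doubles the count (50% overlap).
assert get_num_segments(4.5, 1.0, 0.5) == 8

# The precomputed metadata then pairs each file with its start offsets:
starts = [i * 1.0 for i in range(get_num_segments(4.5, 1.0, 1.0))]
print(starts)  # [0.0, 1.0, 2.0, 3.0, 4.0]
```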
@@ -142,11 +174,33 @@ class TaskDataset(pl.LightningDataModule):
         )
         return metadata

+    def train_collatefn(self, batch):
+        output = {"noisy": [], "clean": []}
+        for item in batch:
+            output["noisy"].append(item["noisy"])
+            output["clean"].append(item["clean"])
+
+        output["clean"] = torch.stack(output["clean"], dim=0)
+        output["noisy"] = torch.stack(output["noisy"], dim=0)
+        return output
+
+    def worker_init_fn(self):
+        worker_info = torch.utils.data.get_worker_info()
+        dataset = worker_info.dataset
+        worker_id = worker_info.id
+        split_size = len(dataset.data) // worker_info.num_workers
+        dataset.data = dataset.data[
+            worker_id * split_size : (worker_id + 1) * split_size
+        ]
+
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=self.batch_size,
+            batch_size=None,
             num_workers=self.num_workers,
+            collate_fn=self.train_collatefn,
+            worker_init_fn=self.worker_init_fn,
         )

     def val_dataloader(self):
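
With `batch_size=None` the DataLoader disables automatic batching: each item the iterable dataset yields is already a full batch (a tuple of `batch_size` samples from `train__iter__` below), and `train_collatefn` only stacks it into tensors. `worker_init_fn` then shards the file list so multiple workers do not replay identical streams. Note that PyTorch calls `worker_init_fn` with the worker id as its single argument, so the conventional shape is a one-argument function like the sketch below (which assumes, as the diff does, that the wrapped dataset exposes the file list as `.data`):

```python
import torch

# Conventional DataLoader worker_init_fn: invoked once in each worker
# process with the integer worker id as the only argument.
def worker_init_fn(worker_id: int) -> None:
    info = torch.utils.data.get_worker_info()
    dataset = info.dataset               # this worker's copy of the dataset
    per_worker = len(dataset.data) // info.num_workers
    start = worker_id * per_worker
    dataset.data = dataset.data[start : start + per_worker]
```

Integer division drops the remainder, so with N workers up to N-1 trailing files are assigned to no worker; that is a trade-off of this simple contiguous-slice sharding.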
@@ -227,24 +281,24 @@ class EnhancerDataset(TaskDataset):
         super().setup(stage=stage)

+    def random_sample(self, index):
+        rng = create_unique_rng(self.model.current_epoch, index)
+        return rng.sample(self.train_data, len(self.train_data))
+
     def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-
-        while True:
-            file_dict, *_ = rng.choices(
-                self.train_data,
-                k=1,
-                weights=[file["duration"] for file in self.train_data],
-            )
-            file_duration = file_dict["duration"]
-            num_segments = self.get_num_segments(
-                file_duration, self.duration, self.stride
-            )
-            for index in range(0, num_segments):
-                start_time = index * self.stride
-                yield self.prepare_segment(file_dict, start_time)
+        return zip(
+            *[
+                self.get_stream(self.random_sample(i))
+                for i in range(self.batch_size)
+            ]
+        )
+
+    def get_stream(self, data):
+        return chain.from_iterable(map(self.process_data, cycle(data)))
+
+    def process_data(self, data):
+        for item in data:
+            yield self.prepare_segment(*item)

     @staticmethod
     def get_num_segments(file_duration, duration, stride):
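
The rewritten `train__iter__` replaces the single duration-weighted sampler with `batch_size` parallel streams: each stream shuffles the file list with its own epoch- and index-seeded RNG, cycles it forever, and flattens each file's segment list; `zip` then draws one sample from every stream to form a batch. A minimal standalone demo of the pattern, with toy strings in place of audio segments and a stand-in seed scheme where the diff uses `create_unique_rng(epoch, index)`:

```python
import random
from itertools import chain, cycle, islice

# Toy stand-in for self.train_data: one list of segment ids per file.
files = [["a0", "a1", "a2"], ["b0", "b1"], ["c0"]]

def stream(epoch: int, index: int):
    rng = random.Random(epoch * 1000 + index)  # stand-in for create_unique_rng
    order = rng.sample(files, len(files))      # per-stream shuffle of the files
    return chain.from_iterable(cycle(order))   # endless flat stream of segments

batch_size = 4
batches = zip(*[stream(epoch=0, index=i) for i in range(batch_size)])
for batch in islice(batches, 3):
    print(batch)  # a tuple of batch_size segments, one drawn from each stream
```

Because every stream is infinite (`cycle`), `zip` never exhausts; epoch length is bounded by `train__len__` instead.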
@@ -264,6 +318,7 @@ class EnhancerDataset(TaskDataset):
     def prepare_segment(self, file_dict: dict, start_time: float):
+        print(file_dict["clean"].split("/")[-1], "->", start_time)
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -292,16 +347,7 @@ class EnhancerDataset(TaskDataset):
     def train__len__(self):
-        return math.ceil(
-            sum(
-                [
-                    self.get_num_segments(
-                        file["duration"], self.duration, self.stride
-                    )
-                    for file in self.train_data
-                ]
-            )
-        )
+        return sum([len(item) for item in self.train_data])

     def val__len__(self):
         return len(self._validation)
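
Since `prepare_traindata` materializes one metadata tuple per window, the epoch length is now simply the total tuple count. This matches the old formula, because each inner list holds exactly `get_num_segments(...)` entries; a quick standalone check of the equivalence using toy durations:

```python
import math

def get_num_segments(file_duration, duration, stride):
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1

duration, stride = 1.0, 1.0
durations = [4.5, 2.0, 0.7]  # toy per-file durations in seconds

# Old length: sum the per-file formula directly.
old_len = sum(get_num_segments(d, duration, stride) for d in durations)

# New length: count the precomputed metadata tuples per file.
train_data = [
    [("meta", i * stride) for i in range(get_num_segments(d, duration, stride))]
    for d in durations
]
new_len = sum(len(item) for item in train_data)

assert old_len == new_len == 5 + 2 + 1
```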