debug iterative dataset

This commit is contained in:
parent  e4f13946e8
commit  2ad49faa67
@@ -1,9 +1,11 @@
 import math
+import multiprocessing
 import os
+from itertools import chain, cycle
 from typing import Optional
 
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
@@ -55,6 +57,7 @@ class TaskDataset(pl.LightningDataModule):
         files: Files,
         valid_minutes: float = 0.20,
         duration: float = 1.0,
+        stride=None,
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,
@@ -65,6 +68,7 @@ class TaskDataset(pl.LightningDataModule):
         self.name = name
         self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
+        self.stride = stride or duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
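A note on the new default: `stride or duration` maps stride=None to back-to-back windows. There is one truthiness subtlety worth knowing about:

    duration = 1.0
    stride = None
    stride = stride or duration  # None -> 1.0: windows tile with no overlap
    # caveat: an explicit stride of 0 would also fall back to `duration`,
    # because `or` tests truthiness rather than `is None`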
@@ -90,10 +94,11 @@ class TaskDataset(pl.LightningDataModule):
                 self.name, train_clean, train_noisy, self.matching_function
             )
             train_data = fp.prepare_matching_dict()
-            self.train_data, self.val_data = self.train_valid_split(
+            train_data, self.val_data = self.train_valid_split(
                 train_data, valid_minutes=self.valid_minutes, random_state=42
             )
 
+            self.train_data = self.prepare_traindata(train_data)
             self._validation = self.prepare_mapstype(self.val_data)
 
             test_clean = os.path.join(self.root_dir, self.files.test_clean)
@@ -112,7 +117,7 @@ class TaskDataset(pl.LightningDataModule):
         valid_min_now = 0.0
         valid_indices = []
         random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state)
+        rng = create_unique_rng(random_state, 0)
         rng.shuffle(random_indices)
         i = 0
         while valid_min_now <= valid_minutes:
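`create_unique_rng` is not part of this diff; judging from its call sites (shuffle, sample, choices), it presumably builds a seeded random.Random from its arguments, and the extra argument here gives the split its own seed component. A hypothetical sketch of such a helper, with the seed mixing assumed rather than taken from the repository:

    import random

    def create_unique_rng(*seed_parts):
        # Mix the components (e.g. epoch and stream index) into one
        # deterministic seed so each combination gets its own
        # reproducible stream of random numbers.
        return random.Random(hash(seed_parts))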
@@ -126,6 +131,33 @@ class TaskDataset(pl.LightningDataModule):
         valid_data = [item for i, item in enumerate(data) if i in valid_indices]
         return train_data, valid_data
 
+    def prepare_traindata(self, data):
+        train_data = []
+        for item in data:
+            samples_metadata = []
+            clean, noisy, total_dur = item.values()
+            num_segments = self.get_num_segments(
+                total_dur, self.duration, self.stride
+            )
+            for index in range(num_segments):
+                start = index * self.stride
+                samples_metadata.append(
+                    ({"clean": clean, "noisy": noisy}, start)
+                )
+            train_data.append(samples_metadata)
+        print(train_data[:10])
+        return train_data
+
+    @staticmethod
+    def get_num_segments(file_duration, duration, stride):
+
+        if file_duration < duration:
+            num_segments = 1
+        else:
+            num_segments = math.ceil((file_duration - duration) / stride) + 1
+
+        return num_segments
+
     def prepare_mapstype(self, data):
 
         metadata = []
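The segment count here is plain window arithmetic: one window of `duration` seconds every `stride` seconds, with at least one window per file. A self-contained check of the formula (the trailing window may run past the end of the file; prepare_segment presumably pads or truncates it):

    import math

    def get_num_segments(file_duration, duration, stride):
        if file_duration < duration:
            return 1
        return math.ceil((file_duration - duration) / stride) + 1

    assert get_num_segments(3.5, 1.0, 1.0) == 4  # starts at 0, 1, 2, 3 s
    assert get_num_segments(0.6, 1.0, 1.0) == 1  # shorter than one window
    assert get_num_segments(3.0, 1.0, 0.5) == 5  # 50% overlap: 0.0 .. 2.0 s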
@@ -142,11 +174,33 @@ class TaskDataset(pl.LightningDataModule):
             )
         return metadata
 
+    def train_collatefn(self, batch):
+
+        output = {"noisy": [], "clean": []}
+        for item in batch:
+            output["noisy"].append(item["noisy"])
+            output["clean"].append(item["clean"])
+
+        output["clean"] = torch.stack(output["clean"], dim=0)
+        output["noisy"] = torch.stack(output["noisy"], dim=0)
+        return output
+
+    def worker_init_fn(self):
+        worker_info = torch.utils.data.get_worker_info()
+        dataset = worker_info.dataset
+        worker_id = worker_info.id
+        split_size = len(dataset.data) // worker_info.num_workers
+        dataset.data = dataset.data[
+            worker_id * split_size : (worker_id + 1) * split_size
+        ]
+
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=self.batch_size,
+            batch_size=None,
             num_workers=self.num_workers,
+            collate_fn=self.train_collatefn,
+            worker_init_fn=self.worker_init_fn,
         )
 
     def val_dataloader(self):
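Two DataLoader details interact here. With batch_size=None, automatic batching is disabled and collate_fn is applied to each item the dataset yields; since train__iter__ (below) already yields batch_size samples per item, train_collatefn receives a ready-made batch to stack. Note also that DataLoader calls worker_init_fn with the worker id, so the usual free-function form looks like the sketch below (assuming, as the method above does, that the wrapped dataset keeps its items in a `.data` attribute):

    import torch

    def worker_init_fn(worker_id):  # DataLoader passes in the worker id
        info = torch.utils.data.get_worker_info()
        dataset = info.dataset  # this worker's copy of the dataset
        # Give every worker a contiguous, non-overlapping shard so no
        # segment is served by two workers at once.
        shard = len(dataset.data) // info.num_workers
        dataset.data = dataset.data[worker_id * shard : (worker_id + 1) * shard]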
@@ -227,24 +281,24 @@ class EnhancerDataset(TaskDataset):
 
         super().setup(stage=stage)
 
+    def random_sample(self, index):
+        rng = create_unique_rng(self.model.current_epoch, index)
+        return rng.sample(self.train_data, len(self.train_data))
+
     def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-
-        while True:
-
-            file_dict, *_ = rng.choices(
-                self.train_data,
-                k=1,
-                weights=[file["duration"] for file in self.train_data],
-            )
-            file_duration = file_dict["duration"]
-            num_segments = self.get_num_segments(
-                file_duration, self.duration, self.stride
-            )
-            for index in range(0, num_segments):
-                start_time = index * self.stride
-                yield self.prepare_segment(file_dict, start_time)
+        return zip(
+            *[
+                self.get_stream(self.random_sample(i))
+                for i in range(self.batch_size)
+            ]
+        )
+
+    def get_stream(self, data):
+        return chain.from_iterable(map(self.process_data, cycle(data)))
+
+    def process_data(self, data):
+        for item in data:
+            yield self.prepare_segment(*item)
 
     @staticmethod
     def get_num_segments(file_duration, duration, stride):
 
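The iteration scheme: each of the batch_size streams is endless, built by cycling over that stream's own shuffled copy of the per-file segment lists and flattening each file into individual segments; zip then draws one segment from every stream per batch. A minimal pure-Python sketch of the pattern, with toy strings standing in for prepare_segment output:

    import random
    from itertools import chain, cycle, islice

    def process_data(file_segments):
        for segment in file_segments:
            yield segment  # stand-in for prepare_segment(*item)

    def get_stream(files):
        # Endless stream: cycle over the files, flatten their segments.
        return chain.from_iterable(map(process_data, cycle(files)))

    files = [["a0", "a1"], ["b0"]]  # two files, three segments total
    streams = [
        get_stream(random.Random(i).sample(files, len(files)))  # per-stream shuffle
        for i in range(2)                                       # like random_sample(i)
    ]
    for batch in islice(zip(*streams), 3):
        print(batch)  # each tuple is one "batch" of 2 segments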
@@ -264,6 +318,7 @@ class EnhancerDataset(TaskDataset):
 
     def prepare_segment(self, file_dict: dict, start_time: float):
 
+        print(file_dict["clean"].split("/")[-1], "->", start_time)
         clean_segment = self.audio(
             file_dict["clean"], offset=start_time, duration=self.duration
         )
@@ -292,16 +347,7 @@ class EnhancerDataset(TaskDataset):
 
     def train__len__(self):
 
-        return math.ceil(
-            sum(
-                [
-                    self.get_num_segments(
-                        file["duration"], self.duration, self.stride
-                    )
-                    for file in self.train_data
-                ]
-            )
-        )
+        return sum([len(item) for item in self.train_data])
 
     def val__len__(self):
         return len(self._validation)
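Since prepare_traindata materializes one segment list per file, the new train__len__ just counts what was laid out up front, which should agree with the old duration-based formula. Reusing the get_num_segments sketch above with hypothetical durations:

    durations = [3.5, 0.6, 3.0]  # hypothetical file durations in seconds
    per_file = [["segment"] * get_num_segments(d, 1.0, 1.0) for d in durations]
    assert sum(len(item) for item in per_file) == 4 + 1 + 3  # 8 segments total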