change to mapstyle
This commit is contained in:
parent 02192e5567
commit 40e2d6e0b0
@@ -1,14 +1,12 @@
 import math
 import multiprocessing
 import os
-import random
-from itertools import chain, cycle
 from typing import Optional
 
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset
 
 from enhancer.data.fileprocessor import Fileprocessor
 from enhancer.utils import check_files
@@ -16,13 +14,15 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng
 
 LARGE_NUM = 2147483647
 
-class TrainDataset(IterableDataset):
+class TrainDataset(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset
 
-    def __iter__(self):
-        return self.dataset.train__iter__()
+    def __getitem__(self, idx):
+        return self.dataset.train__getitem__(idx)
 
     def __len__(self):
         return self.dataset.train__len__()
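With the switch from IterableDataset to a map-style Dataset, the wrapper exposes random access through __getitem__/__len__ instead of a stream through __iter__. A minimal usage sketch of that contract, not part of the commit, where `task` is a hypothetical name for the owning EnhancerDataset instance:

# Hypothetical sketch: a map-style Dataset only needs __getitem__ and
# __len__; the DataLoader's sampler chooses the indices.
ds = TrainDataset(task)   # task: the EnhancerDataset instance (assumed)
sample = ds[0]            # delegates to task.train__getitem__(0)
n = len(ds)               # delegates to task.train__len__()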
@@ -135,16 +135,11 @@ class TaskDataset(pl.LightningDataModule):
     def prepare_traindata(self, data):
         train_data = []
         for item in data:
-            samples_metadata = []
             clean, noisy, total_dur = item.values()
             num_segments = self.get_num_segments(
                 total_dur, self.duration, self.stride
             )
-            for index in range(num_segments):
-                start = index * self.stride
-                samples_metadata.append(
-                    ({"clean": clean, "noisy": noisy}, start)
-                )
+            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
             train_data.append(samples_metadata)
         return train_data
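The metadata layout changes here: instead of one (filedict, start) entry per segment, prepare_traindata now keeps a single (filedict, num_segments) entry per file, and start offsets are derived on demand in train__getitem__. A worked example under assumed values (duration 2.0 s, stride 1.0 s, a 10.0 s file):

import math

duration, stride, total_dur = 2.0, 1.0, 10.0   # assumed values
num_segments = math.ceil((total_dur - duration) / stride) + 1   # -> 9

# old layout: nine explicit entries, one per segment
#   [({"clean": c, "noisy": n}, 0.0), ({"clean": c, "noisy": n}, 1.0), ...]
# new layout: one entry per file, starts computed lazily
#   ({"clean": c, "noisy": n}, 9)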
@@ -175,31 +170,20 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
         output = {"noisy": [], "clean": []}
         for item in batch:
             output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
-        raise NotImplementedError("Not implemented")
 
         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)
         return output
 
-    def worker_init_fn(self, _):
-        worker_info = torch.utils.data.get_worker_info()
-        dataset = worker_info.dataset
-        worker_id = worker_info.id
-        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
-        dataset.data = dataset.dataset.train_data[
-            worker_id * split_size : (worker_id + 1) * split_size
-        ]
-
     @property
     def generator(self):
         generator = torch.Generator()
         generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
         return generator
 
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=None,
+            batch_size=self.batch_size,
             num_workers=self.num_workers,
-            worker_init_fn=self.worker_init_fn,
+            collate_fn=self.train_collatefn,
             generator=self.generator,
         )
 
     def val_dataloader(self):
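Because the dataset is now map-style, the DataLoader batches indices itself: batch_size moves from None to self.batch_size, a collate_fn stacks the per-sample dicts, and the manual worker_init_fn sharding becomes unnecessary since the sampler already splits indices across workers. A sketch of what the collate step produces, assuming each sample is a dict of equal-length 1-D tensors (sizes are illustrative):

import torch

# four hypothetical samples of 1 s of 16 kHz audio
batch = [{"clean": torch.zeros(16000), "noisy": torch.zeros(16000)}
         for _ in range(4)]
out = task.train_collatefn(batch)   # task: the data module (assumed name)
assert out["clean"].shape == (4, 16000)
assert out["noisy"].shape == (4, 16000)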
@@ -280,35 +264,16 @@ class EnhancerDataset(TaskDataset):
 
         super().setup(stage=stage)
 
-    def random_sample(self, train_data):
-        return random.sample(train_data, len(train_data))
-
-    def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-        train_data = rng.sample(self.train_data, len(self.train_data))
-        return zip(
-            *[
-                self.get_stream(self.random_sample(train_data))
-                for i in range(self.batch_size)
-            ]
-        )
-
-    def get_stream(self, data):
-        return chain.from_iterable(map(self.process_data, cycle(data)))
-
-    def process_data(self, data):
-        for item in data:
-            yield self.prepare_segment(*item)
-
-    @staticmethod
-    def get_num_segments(file_duration, duration, stride):
-
-        if file_duration < duration:
-            num_segments = 1
-        else:
-            num_segments = math.ceil((file_duration - duration) / stride) + 1
-
-        return num_segments
+    def train__getitem__(self, idx):
+        for filedict, num_samples in self.train_data:
+            if idx >= num_samples:
+                idx -= num_samples
+                continue
+            start = 0
+            if self.duration is not None:
+                start = idx * self.stride
+            return self.prepare_segment(filedict, start)
 
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
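train__getitem__ resolves a flat index by walking the (filedict, num_segments) pairs and subtracting until the index falls inside a file; the remainder becomes the local segment index, so start = local_idx * stride. A standalone sketch of the same walk with hypothetical data:

train_data = [("fileA", 3), ("fileB", 5)]   # hypothetical entries
stride = 1.0                                # assumed stride in seconds

def locate(idx):
    # same subtraction walk as train__getitem__
    for name, num_samples in train_data:
        if idx >= num_samples:
            idx -= num_samples
            continue
        return name, idx * stride

assert locate(4) == ("fileB", 1.0)   # global idx 4 -> fileB, segment 1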
@@ -348,7 +313,8 @@ class EnhancerDataset(TaskDataset):
         }
 
     def train__len__(self):
-        return sum([len(item) for item in self.train_data]) // (self.batch_size)
+        _, num_examples = list(zip(*self.train_data))
+        return sum(num_examples)
 
     def val__len__(self):
         return len(self._validation)
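train__len__ now reports the total number of segments across all files rather than a batch count, since batching is the DataLoader's job for a map-style dataset. A small self-contained sketch with hypothetical entries:

train_data = [("fileA", 3), ("fileB", 5)]    # hypothetical (filedict, num_segments) pairs
_, num_examples = list(zip(*train_data))     # -> (3, 5)
assert sum(num_examples) == 8                # total segments, not batches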