change to map-style dataset
This commit is contained in:
parent
02192e5567
commit
40e2d6e0b0
|
|
@@ -1,14 +1,12 @@
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import random
|
|
||||||
from itertools import chain, cycle
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from enhancer.data.fileprocessor import Fileprocessor
|
from enhancer.data.fileprocessor import Fileprocessor
|
||||||
from enhancer.utils import check_files
|
from enhancer.utils import check_files
|
||||||
|
|
@@ -16,13 +14,15 @@ from enhancer.utils.config import Files
|
||||||
from enhancer.utils.io import Audio
|
from enhancer.utils.io import Audio
|
||||||
from enhancer.utils.random import create_unique_rng
|
from enhancer.utils.random import create_unique_rng
|
||||||
|
|
||||||
|
LARGE_NUM = 2147483647
|
||||||
|
|
||||||
class TrainDataset(IterableDataset):
|
|
||||||
|
class TrainDataset(Dataset):
|
||||||
def __init__(self, dataset):
|
def __init__(self, dataset):
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
|
|
||||||
def __iter__(self):
|
def __getitem__(self, idx):
|
||||||
return self.dataset.train__iter__()
|
return self.dataset.train__getitem__(idx)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.dataset.train__len__()
|
return self.dataset.train__len__()
|
||||||
|
|
@@ -135,16 +135,11 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
def prepare_traindata(self, data):
|
def prepare_traindata(self, data):
|
||||||
train_data = []
|
train_data = []
|
||||||
for item in data:
|
for item in data:
|
||||||
samples_metadata = []
|
|
||||||
clean, noisy, total_dur = item.values()
|
clean, noisy, total_dur = item.values()
|
||||||
num_segments = self.get_num_segments(
|
num_segments = self.get_num_segments(
|
||||||
total_dur, self.duration, self.stride
|
total_dur, self.duration, self.stride
|
||||||
)
|
)
|
||||||
for index in range(num_segments):
|
samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
|
||||||
start = index * self.stride
|
|
||||||
samples_metadata.append(
|
|
||||||
({"clean": clean, "noisy": noisy}, start)
|
|
||||||
)
|
|
||||||
train_data.append(samples_metadata)
|
train_data.append(samples_metadata)
|
||||||
return train_data
|
return train_data
|
||||||
|
|
||||||
|
|
@@ -175,31 +170,20 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def train_collatefn(self, batch):
|
def train_collatefn(self, batch):
|
||||||
output = {"noisy": [], "clean": []}
|
raise NotImplementedError("Not implemented")
|
||||||
for item in batch:
|
|
||||||
output["noisy"].append(item["noisy"])
|
|
||||||
output["clean"].append(item["clean"])
|
|
||||||
|
|
||||||
output["clean"] = torch.stack(output["clean"], dim=0)
|
@property
|
||||||
output["noisy"] = torch.stack(output["noisy"], dim=0)
|
def generator(self):
|
||||||
return output
|
generator = torch.Generator()
|
||||||
|
generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
|
||||||
def worker_init_fn(self, _):
|
return generator
|
||||||
worker_info = torch.utils.data.get_worker_info()
|
|
||||||
dataset = worker_info.dataset
|
|
||||||
worker_id = worker_info.id
|
|
||||||
split_size = len(dataset.dataset.train_data) // worker_info.num_workers
|
|
||||||
dataset.data = dataset.dataset.train_data[
|
|
||||||
worker_id * split_size : (worker_id + 1) * split_size
|
|
||||||
]
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
TrainDataset(self),
|
TrainDataset(self),
|
||||||
batch_size=None,
|
batch_size=self.batch_size,
|
||||||
num_workers=self.num_workers,
|
num_workers=self.num_workers,
|
||||||
collate_fn=self.train_collatefn,
|
generator=self.generator,
|
||||||
worker_init_fn=self.worker_init_fn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
|
|
@@ -280,35 +264,16 @@ class EnhancerDataset(TaskDataset):
|
||||||
|
|
||||||
super().setup(stage=stage)
|
super().setup(stage=stage)
|
||||||
|
|
||||||
def random_sample(self, train_data):
|
def train__getitem__(self, idx):
|
||||||
return random.sample(train_data, len(train_data))
|
|
||||||
|
|
||||||
def train__iter__(self):
|
for filedict, num_samples in self.train_data:
|
||||||
rng = create_unique_rng(self.model.current_epoch)
|
if idx >= num_samples:
|
||||||
train_data = rng.sample(self.train_data, len(self.train_data))
|
idx -= num_samples
|
||||||
return zip(
|
continue
|
||||||
*[
|
start = 0
|
||||||
self.get_stream(self.random_sample(train_data))
|
if self.duration is not None:
|
||||||
for i in range(self.batch_size)
|
start = idx * self.stride
|
||||||
]
|
return self.prepare_segment(filedict, start)
|
||||||
)
|
|
||||||
|
|
||||||
def get_stream(self, data):
|
|
||||||
return chain.from_iterable(map(self.process_data, cycle(data)))
|
|
||||||
|
|
||||||
def process_data(self, data):
|
|
||||||
for item in data:
|
|
||||||
yield self.prepare_segment(*item)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_num_segments(file_duration, duration, stride):
|
|
||||||
|
|
||||||
if file_duration < duration:
|
|
||||||
num_segments = 1
|
|
||||||
else:
|
|
||||||
num_segments = math.ceil((file_duration - duration) / stride) + 1
|
|
||||||
|
|
||||||
return num_segments
|
|
||||||
|
|
||||||
def val__getitem__(self, idx):
|
def val__getitem__(self, idx):
|
||||||
return self.prepare_segment(*self._validation[idx])
|
return self.prepare_segment(*self._validation[idx])
|
||||||
|
|
@@ -348,7 +313,8 @@ class EnhancerDataset(TaskDataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
def train__len__(self):
|
def train__len__(self):
|
||||||
return sum([len(item) for item in self.train_data]) // (self.batch_size)
|
_, num_examples = list(zip(*self.train_data))
|
||||||
|
return sum(num_examples)
|
||||||
|
|
||||||
def val__len__(self):
|
def val__len__(self):
|
||||||
return len(self._validation)
|
return len(self._validation)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue