change to mapstyle

This commit is contained in:
shahules786 2022-10-23 12:32:58 +05:30
parent 02192e5567
commit 40e2d6e0b0
1 changed file with 26 additions and 60 deletions

View File

@@ -1,14 +1,12 @@
import math
import multiprocessing
import os
import random
from itertools import chain, cycle
from typing import Optional
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data import DataLoader, Dataset
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
@@ -16,13 +14,15 @@ from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng
LARGE_NUM = 2147483647
class TrainDataset(Dataset):
    """Map-style dataset adapter for the training split.

    Delegates indexing and length to the owning TaskDataset's
    ``train__getitem__`` / ``train__len__`` hooks, so the DataLoader can
    batch and shuffle it like any ordinary map-style dataset.

    NOTE(review): the scraped diff superimposed the pre-commit
    IterableDataset variant (``__iter__``) and the post-commit map-style
    variant (``__getitem__``/``__len__``); resolved here to the latter,
    matching the commit message "change to mapstyle".
    """

    def __init__(self, dataset):
        # `dataset` is the owning TaskDataset-like object, not raw data.
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.train__getitem__(idx)

    def __len__(self):
        return self.dataset.train__len__()
@@ -135,16 +135,11 @@ class TaskDataset(pl.LightningDataModule):
def prepare_traindata(self, data):
    """Build per-file training metadata for the map-style dataset.

    Parameters
    ----------
    data : iterable of dict
        One dict per file; ``item.values()`` is unpacked positionally as
        (clean path, noisy path, total duration) — NOTE(review): this
        relies on dict insertion order matching that layout; confirm the
        producer's key order.

    Returns
    -------
    list of ({"clean": ..., "noisy": ...}, num_segments) tuples, where
    num_segments counts the stride-spaced windows self.get_num_segments
    yields for the file.

    NOTE(review): the scraped diff superimposed the pre-commit
    per-segment expansion loop and the post-commit per-file tuple;
    resolved here to the post-commit form.
    """
    train_data = []
    for item in data:
        clean, noisy, total_dur = item.values()
        num_segments = self.get_num_segments(
            total_dur, self.duration, self.stride
        )
        train_data.append(({"clean": clean, "noisy": noisy}, num_segments))
    return train_data
@@ -175,31 +170,20 @@ class TaskDataset(pl.LightningDataModule):
return metadata
# NOTE(review): residue of a rendered git diff whose +/- markers were
# stripped — this span superimposes the pre- and post-commit bodies of
# train_collatefn and is not valid Python as it stands.
def train_collatefn(self, batch):
output = {"noisy": [], "clean": []}
# Presumably the pre-commit body (removed): gather each sample's tensors
# per key, then stack into batch tensors — TODO confirm against the repo.
for item in batch:
output["noisy"].append(item["noisy"])
output["clean"].append(item["clean"])
# Presumably the post-commit addition: custom collation disabled after the
# "change to mapstyle" commit — TODO confirm against the repo.
raise NotImplementedError("Not implemented")
output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0)
return output
# NOTE(review): this worker-sharding hook serves the pre-commit
# IterableDataset pipeline (each worker takes a contiguous slice of the
# training list); presumably deleted by this "change to mapstyle" commit —
# confirm against the repository.
def worker_init_fn(self, _):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
# Integer division: each of num_workers workers gets an equal-sized slice;
# any remainder files at the tail are dropped.
split_size = len(dataset.dataset.train_data) // worker_info.num_workers
dataset.data = dataset.dataset.train_data[
worker_id * split_size : (worker_id + 1) * split_size
]
@property
def generator(self):
    """A fresh torch RNG, deterministically seeded from the current epoch.

    Offsetting the epoch by LARGE_NUM keeps the seed distinct from other
    epoch-derived seeds used elsewhere in the pipeline.
    """
    rng = torch.Generator()
    # Generator.manual_seed returns the generator itself, so this single
    # expression both seeds and yields the RNG.
    return rng.manual_seed(self.model.current_epoch + LARGE_NUM)
# NOTE(review): diff residue — the two consecutive batch_size keyword lines
# below are the pre-commit (batch_size=None, iterable-style) and post-commit
# (batch_size=self.batch_size, map-style) variants of the same argument; as
# written the duplicate keyword is a SyntaxError. Only one belongs in either
# version of the file.
def train_dataloader(self):
return DataLoader(
TrainDataset(self),
batch_size=None,
batch_size=self.batch_size,
num_workers=self.num_workers,
# NOTE(review): whether collate_fn/worker_init_fn survive the map-style
# rewrite cannot be determined from this scrape — confirm against the repo.
collate_fn=self.train_collatefn,
worker_init_fn=self.worker_init_fn,
generator=self.generator,
)
def val_dataloader(self):
@@ -280,35 +264,16 @@ class EnhancerDataset(TaskDataset):
super().setup(stage=stage)
def random_sample(self, train_data):
    """Return the elements of train_data in a fresh random order.

    Sampling k == len(train_data) items yields a shuffled copy while
    leaving the input list itself untouched.
    """
    shuffled = random.sample(train_data, len(train_data))
    return shuffled
# NOTE(review): residue of a rendered git diff with +/- markers stripped.
# The next two `def` lines are the post-commit and pre-commit headers of the
# same slot: train__getitem__ replaces train__iter__ per the commit message
# "change to mapstyle" — confirm against the repository. The body of the new
# train__getitem__ appears further down in this hunk.
def train__getitem__(self, idx):
def train__iter__(self):
# Pre-commit iterable pipeline (presumably removed): epoch-seeded shuffle
# feeding batch_size parallel, endlessly cycling segment streams.
rng = create_unique_rng(self.model.current_epoch)
train_data = rng.sample(self.train_data, len(self.train_data))
return zip(
*[
self.get_stream(self.random_sample(train_data))
for i in range(self.batch_size)
]
)
# Helper of the pre-commit pipeline: flatten segments from a cycling stream.
def get_stream(self, data):
return chain.from_iterable(map(self.process_data, cycle(data)))
# Helper of the pre-commit pipeline: lazily materialize each segment.
def process_data(self, data):
for item in data:
yield self.prepare_segment(*item)
@staticmethod
def get_num_segments(file_duration, duration, stride):
    """Count the stride-spaced windows of length `duration` in a file.

    A file shorter than one window still contributes a single segment;
    otherwise the count is one window at offset 0 plus one per full
    stride that fits in the remaining (file_duration - duration) span,
    rounded up so the tail is covered.
    """
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1
# NOTE(review): these lines are the post-commit body of train__getitem__
# (its `def` line appears earlier in this diff hunk; the scrape stripped the
# +/- markers that tied them together). Walks the per-file
# (filedict, num_samples) list, skipping whole files until idx lands inside
# one, then prepares that file's idx-th segment.
for filedict, num_samples in self.train_data:
if idx >= num_samples:
idx -= num_samples
continue
# Segment 0 starts at offset 0; later segments advance by self.stride.
start = 0
if self.duration is not None:
start = idx * self.stride
return self.prepare_segment(filedict, start)
def val__getitem__(self, idx):
    """Prepare and return the validation sample stored at position idx.

    Each _validation entry holds the positional arguments expected by
    prepare_segment.
    """
    sample_meta = self._validation[idx]
    return self.prepare_segment(*sample_meta)
@@ -348,7 +313,8 @@ class EnhancerDataset(TaskDataset):
}
def train__len__(self):
    """Total number of training segments (sum of per-file segment counts).

    NOTE(review): the scraped diff superimposed the pre-commit per-batch
    count (``// batch_size``) and the post-commit per-segment sum; resolved
    here to the post-commit form, which matches map-style per-segment
    indexing in train__getitem__. A generator sum is used instead of
    ``zip(*self.train_data)`` so an empty train_data yields 0 rather than
    raising on unpack.
    """
    return sum(num_examples for _, num_examples in self.train_data)
def val__len__(self):
    """Number of prepared validation examples."""
    validation_metadata = self._validation
    return len(validation_metadata)