change to mapstyle

shahules786 2022-10-23 12:32:58 +05:30
parent 02192e5567
commit 40e2d6e0b0
1 changed file with 26 additions and 60 deletions


@@ -1,14 +1,12 @@
 import math
 import multiprocessing
 import os
-import random
-from itertools import chain, cycle
 from typing import Optional

 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset

 from enhancer.data.fileprocessor import Fileprocessor
 from enhancer.utils import check_files
@@ -16,13 +14,15 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng

+LARGE_NUM = 2147483647
+

-class TrainDataset(IterableDataset):
+class TrainDataset(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset

-    def __iter__(self):
-        return self.dataset.train__iter__()
+    def __getitem__(self, idx):
+        return self.dataset.train__getitem__(idx)

     def __len__(self):
         return self.dataset.train__len__()
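
Note: TrainDataset now follows PyTorch's map-style protocol (__getitem__/__len__) instead of the iterable-style __iter__, which lets DataLoader handle sampling, batching, and worker sharding itself. A minimal standalone sketch of the pattern (toy class, not the project's code):

import torch
from torch.utils.data import DataLoader, Dataset

class MapStyle(Dataset):
    # DataLoader indexes this via a sampler and assembles batches for us
    def __getitem__(self, idx):
        return torch.tensor(idx)

    def __len__(self):
        return 8

loader = DataLoader(MapStyle(), batch_size=4, shuffle=True)
for batch in loader:
    print(batch.shape)  # torch.Size([4])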
@@ -135,16 +135,11 @@ class TaskDataset(pl.LightningDataModule):
     def prepare_traindata(self, data):
         train_data = []
         for item in data:
-            samples_metadata = []
             clean, noisy, total_dur = item.values()
             num_segments = self.get_num_segments(
                 total_dur, self.duration, self.stride
             )
-            for index in range(num_segments):
-                start = index * self.stride
-                samples_metadata.append(
-                    ({"clean": clean, "noisy": noisy}, start)
-                )
+            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
             train_data.append(samples_metadata)
         return train_data
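
prepare_traindata now stores a single (filedict, num_segments) tuple per file instead of one entry per segment start. The segment count comes from get_num_segments, whose formula, visible in the removed EnhancerDataset code further down, is ceil((file_duration - duration) / stride) + 1; e.g. a 10 s file with duration=2 and stride=1 yields 9 segments. A quick standalone check:

import math

def get_num_segments(file_duration, duration, stride):
    # a file shorter than the window still contributes one segment
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1

print(get_num_segments(10.0, 2.0, 1.0))  # 9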
@@ -175,31 +170,20 @@ class TaskDataset(pl.LightningDataModule):
         return metadata

     def train_collatefn(self, batch):
-        output = {"noisy": [], "clean": []}
-        for item in batch:
-            output["noisy"].append(item["noisy"])
-            output["clean"].append(item["clean"])
-        output["clean"] = torch.stack(output["clean"], dim=0)
-        output["noisy"] = torch.stack(output["noisy"], dim=0)
-        return output
-
-    def worker_init_fn(self, _):
-        worker_info = torch.utils.data.get_worker_info()
-        dataset = worker_info.dataset
-        worker_id = worker_info.id
-        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
-        dataset.data = dataset.dataset.train_data[
-            worker_id * split_size : (worker_id + 1) * split_size
-        ]
+        raise NotImplementedError("Not implemented")
+
+    @property
+    def generator(self):
+        generator = torch.Generator()
+        generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
+        return generator

     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=None,
+            batch_size=self.batch_size,
             num_workers=self.num_workers,
-            collate_fn=self.train_collatefn,
-            worker_init_fn=self.worker_init_fn,
+            generator=self.generator,
         )

     def val_dataloader(self):
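
Note: the new generator property derives a fresh, epoch-dependent seed, so any randomness the DataLoader draws from it (sampling order, worker base seeds) is reproducible per epoch. A minimal sketch of the idea, with a hypothetical epoch value standing in for self.model.current_epoch (shuffle=True added here just to make the effect visible):

import torch
from torch.utils.data import DataLoader, TensorDataset

LARGE_NUM = 2147483647
epoch = 3  # hypothetical; stands in for self.model.current_epoch

g = torch.Generator()
g.manual_seed(epoch + LARGE_NUM)

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=5, shuffle=True, generator=g)
for (batch,) in loader:
    print(batch)  # identical order on every run for the same epoch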
@@ -280,35 +264,16 @@ class EnhancerDataset(TaskDataset):
         super().setup(stage=stage)

-    def random_sample(self, train_data):
-        return random.sample(train_data, len(train_data))
-
-    def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-        train_data = rng.sample(self.train_data, len(self.train_data))
-        return zip(
-            *[
-                self.get_stream(self.random_sample(train_data))
-                for i in range(self.batch_size)
-            ]
-        )
-
-    def get_stream(self, data):
-        return chain.from_iterable(map(self.process_data, cycle(data)))
-
-    def process_data(self, data):
-        for item in data:
-            yield self.prepare_segment(*item)
-
-    @staticmethod
-    def get_num_segments(file_duration, duration, stride):
-        if file_duration < duration:
-            num_segments = 1
-        else:
-            num_segments = math.ceil((file_duration - duration) / stride) + 1
-        return num_segments
+    def train__getitem__(self, idx):
+        for filedict, num_samples in self.train_data:
+            if idx >= num_samples:
+                idx -= num_samples
+                continue
+            start = 0
+            if self.duration is not None:
+                start = idx * self.stride
+            return self.prepare_segment(filedict, start)

     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
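
train__getitem__ treats the corpus as one flat list of segments: it walks the per-file (filedict, num_segments) pairs, subtracting each file's count until the global index lands inside a file, then converts the local index into a start time via the stride. A standalone sketch of that index walk with hypothetical data:

train_data = [("fileA", 3), ("fileB", 5)]  # hypothetical (filedict, num_segments) pairs
stride = 1.0

def locate(idx):
    for name, num_samples in train_data:
        if idx >= num_samples:
            idx -= num_samples
            continue
        return name, idx * stride  # file plus segment start time

print(locate(4))  # ('fileB', 1.0): global index 4 is fileB's second segment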
@@ -348,7 +313,8 @@ class EnhancerDataset(TaskDataset):
         }

     def train__len__(self):
-        return sum([len(item) for item in self.train_data]) // (self.batch_size)
+        _, num_examples = list(zip(*self.train_data))
+        return sum(num_examples)

     def val__len__(self):
         return len(self._validation)
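
train__len__ correspondingly reports the total number of segments; the old division by batch_size is gone because DataLoader now forms the batches. With the hypothetical pairs from the sketch above:

train_data = [("fileA", 3), ("fileB", 5)]  # hypothetical (filedict, num_segments) pairs
_, num_examples = zip(*train_data)
print(sum(num_examples))  # 8 segments in total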