add padding

This commit is contained in:
shahules786 2022-10-29 09:41:56 +05:30
parent aa52d1ed93
commit ad208ca0a0
1 changed files with 22 additions and 10 deletions

View File

@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule):
name: str, name: str,
root_dir: str, root_dir: str,
files: Files, files: Files,
valid_minutes: float = 0.20, min_valid_minutes: float = 0.20,
duration: float = 1.0, duration: float = 1.0,
stride=None, stride=None,
sampling_rate: int = 48000, sampling_rate: int = 48000,
@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule):
if num_workers is None: if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2 num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers self.num_workers = num_workers
if valid_minutes > 0.0: if min_valid_minutes > 0.0:
self.valid_minutes = valid_minutes self.min_valid_minutes = min_valid_minutes
else: else:
raise ValueError("valid_minutes must be greater than 0") raise ValueError("min_valid_minutes must be greater than 0")
self.augmentations = augmentations self.augmentations = augmentations
@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule):
) )
train_data = fp.prepare_matching_dict() train_data = fp.prepare_matching_dict()
train_data, self.val_data = self.train_valid_split( train_data, self.val_data = self.train_valid_split(
train_data, valid_minutes=self.valid_minutes, random_state=42 train_data,
min_valid_minutes=self.min_valid_minutes,
random_state=42,
) )
self.train_data = self.prepare_traindata(train_data) self.train_data = self.prepare_traindata(train_data)
@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule):
self._test = self.prepare_mapstype(test_data) self._test = self.prepare_mapstype(test_data)
def train_valid_split( def train_valid_split(
self, data, valid_minutes: float = 20, random_state: int = 42 self, data, min_valid_minutes: float = 20, random_state: int = 42
): ):
valid_minutes *= 60 min_valid_minutes *= 60
valid_sec_now = 0.0 valid_sec_now = 0.0
valid_indices = [] valid_indices = []
all_speakers = np.unique( all_speakers = np.unique(
@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule):
possible_indices = list(range(0, len(all_speakers))) possible_indices = list(range(0, len(all_speakers)))
rng = create_unique_rng(len(all_speakers)) rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= valid_minutes: while valid_sec_now <= min_valid_minutes:
speaker_index = rng.choice(possible_indices) speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index) possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index] speaker_name = all_speakers[speaker_index]
@ -257,10 +259,15 @@ class EnhancerDataset(TaskDataset):
files : Files files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer enhancer.utils.Files dataclass) folder names (refer enhancer.utils.Files dataclass)
min_valid_minutes: float
minimum validation split size time in minutes
algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.
duration : float duration : float
expected audio duration of single audio sample for training expected audio duration of single audio sample for training
sampling_rate : int sampling_rate : int
desired sampling rate desired sampling rate
padding_mode: str
padding mode (silent,reflect)
batch_size : int batch_size : int
batch size of each batch batch size of each batch
num_workers : int num_workers : int
@ -271,6 +278,7 @@ class EnhancerDataset(TaskDataset):
use one_to_many mapping for multiple noisy files for each clean file use one_to_many mapping for multiple noisy files for each clean file
""" """
def __init__( def __init__(
@ -278,10 +286,11 @@ class EnhancerDataset(TaskDataset):
name: str, name: str,
root_dir: str, root_dir: str,
files: Files, files: Files,
valid_minutes=5.0, min_valid_minutes=5.0,
duration=1.0, duration=1.0,
stride=None, stride=None,
sampling_rate=48000, sampling_rate=48000,
padding_mode: str = "silent",
matching_function=None, matching_function=None,
batch_size=32, batch_size=32,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
@ -292,7 +301,7 @@ class EnhancerDataset(TaskDataset):
name=name, name=name,
root_dir=root_dir, root_dir=root_dir,
files=files, files=files,
valid_minutes=valid_minutes, min_valid_minutes=min_valid_minutes,
sampling_rate=sampling_rate, sampling_rate=sampling_rate,
duration=duration, duration=duration,
matching_function=matching_function, matching_function=matching_function,
@ -306,6 +315,7 @@ class EnhancerDataset(TaskDataset):
self.duration = max(1.0, duration) self.duration = max(1.0, duration)
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True) self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
self.stride = stride or duration self.stride = stride or duration
self.padding_mode = padding_mode
def setup(self, stage: Optional[str] = None): def setup(self, stage: Optional[str] = None):
@ -344,6 +354,7 @@ class EnhancerDataset(TaskDataset):
self.duration * self.sampling_rate - clean_segment.shape[-1] self.duration * self.sampling_rate - clean_segment.shape[-1]
), ),
), ),
mode=self.padding_mode,
) )
noisy_segment = F.pad( noisy_segment = F.pad(
noisy_segment, noisy_segment,
@ -353,6 +364,7 @@ class EnhancerDataset(TaskDataset):
self.duration * self.sampling_rate - noisy_segment.shape[-1] self.duration * self.sampling_rate - noisy_segment.shape[-1]
), ),
), ),
mode=self.padding_mode,
) )
return { return {
"clean": clean_segment, "clean": clean_segment,