add padding

This commit is contained in:
shahules786 2022-10-29 09:41:56 +05:30
parent aa52d1ed93
commit ad208ca0a0
1 changed files with 22 additions and 10 deletions

View File

@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule):
name: str,
root_dir: str,
files: Files,
valid_minutes: float = 0.20,
min_valid_minutes: float = 0.20,
duration: float = 1.0,
stride=None,
sampling_rate: int = 48000,
@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule):
if num_workers is None:
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
if valid_minutes > 0.0:
self.valid_minutes = valid_minutes
if min_valid_minutes > 0.0:
self.min_valid_minutes = min_valid_minutes
else:
raise ValueError("valid_minutes must be greater than 0")
raise ValueError("min_valid_minutes must be greater than 0")
self.augmentations = augmentations
@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule):
)
train_data = fp.prepare_matching_dict()
train_data, self.val_data = self.train_valid_split(
train_data, valid_minutes=self.valid_minutes, random_state=42
train_data,
min_valid_minutes=self.min_valid_minutes,
random_state=42,
)
self.train_data = self.prepare_traindata(train_data)
@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule):
self._test = self.prepare_mapstype(test_data)
def train_valid_split(
self, data, valid_minutes: float = 20, random_state: int = 42
self, data, min_valid_minutes: float = 20, random_state: int = 42
):
valid_minutes *= 60
min_valid_minutes *= 60
valid_sec_now = 0.0
valid_indices = []
all_speakers = np.unique(
@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule):
possible_indices = list(range(0, len(all_speakers)))
rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= valid_minutes:
while valid_sec_now <= min_valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
@ -257,10 +259,15 @@ class EnhancerDataset(TaskDataset):
files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer enhancer.utils.Files dataclass)
min_valid_minutes: float
minimum validation split size time in minutes
algorithm randomly select n speakers (>=min_valid_minutes) from train data to form validation data.
duration : float
expected audio duration of single audio sample for training
sampling_rate : int
desired sampling rate
padding_mode: str
padding mode (silent,reflect)
batch_size : int
batch size of each batch
num_workers : int
@ -271,6 +278,7 @@ class EnhancerDataset(TaskDataset):
use one_to_many mapping for multiple noisy files for each clean file
"""
def __init__(
@ -278,10 +286,11 @@ class EnhancerDataset(TaskDataset):
name: str,
root_dir: str,
files: Files,
valid_minutes=5.0,
min_valid_minutes=5.0,
duration=1.0,
stride=None,
sampling_rate=48000,
padding_mode: str = "silent",
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
@ -292,7 +301,7 @@ class EnhancerDataset(TaskDataset):
name=name,
root_dir=root_dir,
files=files,
valid_minutes=valid_minutes,
min_valid_minutes=min_valid_minutes,
sampling_rate=sampling_rate,
duration=duration,
matching_function=matching_function,
@ -306,6 +315,7 @@ class EnhancerDataset(TaskDataset):
self.duration = max(1.0, duration)
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
self.stride = stride or duration
self.padding_mode = padding_mode
def setup(self, stage: Optional[str] = None):
@ -344,6 +354,7 @@ class EnhancerDataset(TaskDataset):
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
mode=self.padding_mode,
)
noisy_segment = F.pad(
noisy_segment,
@ -353,6 +364,7 @@ class EnhancerDataset(TaskDataset):
self.duration * self.sampling_rate - noisy_segment.shape[-1]
),
),
mode=self.padding_mode,
)
return {
"clean": clean_segment,