add padding
This commit is contained in:
parent
aa52d1ed93
commit
ad208ca0a0
|
|
@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
name: str,
|
||||
root_dir: str,
|
||||
files: Files,
|
||||
valid_minutes: float = 0.20,
|
||||
min_valid_minutes: float = 0.20,
|
||||
duration: float = 1.0,
|
||||
stride=None,
|
||||
sampling_rate: int = 48000,
|
||||
|
|
@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule):
|
|||
if num_workers is None:
|
||||
num_workers = multiprocessing.cpu_count() // 2
|
||||
self.num_workers = num_workers
|
||||
if valid_minutes > 0.0:
|
||||
self.valid_minutes = valid_minutes
|
||||
if min_valid_minutes > 0.0:
|
||||
self.min_valid_minutes = min_valid_minutes
|
||||
else:
|
||||
raise ValueError("valid_minutes must be greater than 0")
|
||||
raise ValueError("min_valid_minutes must be greater than 0")
|
||||
|
||||
self.augmentations = augmentations
|
||||
|
||||
|
|
@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule):
|
|||
)
|
||||
train_data = fp.prepare_matching_dict()
|
||||
train_data, self.val_data = self.train_valid_split(
|
||||
train_data, valid_minutes=self.valid_minutes, random_state=42
|
||||
train_data,
|
||||
min_valid_minutes=self.min_valid_minutes,
|
||||
random_state=42,
|
||||
)
|
||||
|
||||
self.train_data = self.prepare_traindata(train_data)
|
||||
|
|
@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self._test = self.prepare_mapstype(test_data)
|
||||
|
||||
def train_valid_split(
|
||||
self, data, valid_minutes: float = 20, random_state: int = 42
|
||||
self, data, min_valid_minutes: float = 20, random_state: int = 42
|
||||
):
|
||||
|
||||
valid_minutes *= 60
|
||||
min_valid_minutes *= 60
|
||||
valid_sec_now = 0.0
|
||||
valid_indices = []
|
||||
all_speakers = np.unique(
|
||||
|
|
@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
possible_indices = list(range(0, len(all_speakers)))
|
||||
rng = create_unique_rng(len(all_speakers))
|
||||
|
||||
while valid_sec_now <= valid_minutes:
|
||||
while valid_sec_now <= min_valid_minutes:
|
||||
speaker_index = rng.choice(possible_indices)
|
||||
possible_indices.remove(speaker_index)
|
||||
speaker_name = all_speakers[speaker_index]
|
||||
|
|
@ -257,10 +259,15 @@ class EnhancerDataset(TaskDataset):
|
|||
files : Files
|
||||
dataclass containing train_clean, train_noisy, test_clean, test_noisy
|
||||
folder names (refer enhancer.utils.Files dataclass)
|
||||
min_valid_minutes: float
|
||||
minimum total duration of the validation split, in minutes
|
||||
the algorithm randomly selects n speakers (total duration >= min_valid_minutes) from the train data to form the validation data.
|
||||
duration : float
|
||||
expected audio duration of single audio sample for training
|
||||
sampling_rate : int
|
||||
desired sampling rate
|
||||
padding_mode: str
|
||||
padding mode, one of (silent, reflect)
|
||||
batch_size : int
|
||||
batch size of each batch
|
||||
num_workers : int
|
||||
|
|
@ -271,6 +278,7 @@ class EnhancerDataset(TaskDataset):
|
|||
use one_to_many mapping for multiple noisy files for each clean file
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -278,10 +286,11 @@ class EnhancerDataset(TaskDataset):
|
|||
name: str,
|
||||
root_dir: str,
|
||||
files: Files,
|
||||
valid_minutes=5.0,
|
||||
min_valid_minutes=5.0,
|
||||
duration=1.0,
|
||||
stride=None,
|
||||
sampling_rate=48000,
|
||||
padding_mode: str = "silent",
|
||||
matching_function=None,
|
||||
batch_size=32,
|
||||
num_workers: Optional[int] = None,
|
||||
|
|
@ -292,7 +301,7 @@ class EnhancerDataset(TaskDataset):
|
|||
name=name,
|
||||
root_dir=root_dir,
|
||||
files=files,
|
||||
valid_minutes=valid_minutes,
|
||||
min_valid_minutes=min_valid_minutes,
|
||||
sampling_rate=sampling_rate,
|
||||
duration=duration,
|
||||
matching_function=matching_function,
|
||||
|
|
@ -306,6 +315,7 @@ class EnhancerDataset(TaskDataset):
|
|||
self.duration = max(1.0, duration)
|
||||
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
|
||||
self.stride = stride or duration
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
|
||||
|
|
@ -344,6 +354,7 @@ class EnhancerDataset(TaskDataset):
|
|||
self.duration * self.sampling_rate - clean_segment.shape[-1]
|
||||
),
|
||||
),
|
||||
mode=self.padding_mode,
|
||||
)
|
||||
noisy_segment = F.pad(
|
||||
noisy_segment,
|
||||
|
|
@ -353,6 +364,7 @@ class EnhancerDataset(TaskDataset):
|
|||
self.duration * self.sampling_rate - noisy_segment.shape[-1]
|
||||
),
|
||||
),
|
||||
mode=self.padding_mode,
|
||||
)
|
||||
return {
|
||||
"clean": clean_segment,
|
||||
|
|
|
|||
Loading…
Reference in New Issue