add padding

parent aa52d1ed93
commit ad208ca0a0
@@ -59,7 +59,7 @@ class TaskDataset(pl.LightningDataModule):
         name: str,
         root_dir: str,
         files: Files,
-        valid_minutes: float = 0.20,
+        min_valid_minutes: float = 0.20,
         duration: float = 1.0,
         stride=None,
         sampling_rate: int = 48000,
@@ -81,10 +81,10 @@ class TaskDataset(pl.LightningDataModule):
         if num_workers is None:
             num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
-        if valid_minutes > 0.0:
-            self.valid_minutes = valid_minutes
+        if min_valid_minutes > 0.0:
+            self.min_valid_minutes = min_valid_minutes
         else:
-            raise ValueError("valid_minutes must be greater than 0")
+            raise ValueError("min_valid_minutes must be greater than 0")

         self.augmentations = augmentations

@@ -102,7 +102,9 @@ class TaskDataset(pl.LightningDataModule):
         )
         train_data = fp.prepare_matching_dict()
         train_data, self.val_data = self.train_valid_split(
-            train_data, valid_minutes=self.valid_minutes, random_state=42
+            train_data,
+            min_valid_minutes=self.min_valid_minutes,
+            random_state=42,
         )

         self.train_data = self.prepare_traindata(train_data)
@@ -117,10 +119,10 @@ class TaskDataset(pl.LightningDataModule):
         self._test = self.prepare_mapstype(test_data)

     def train_valid_split(
-        self, data, valid_minutes: float = 20, random_state: int = 42
+        self, data, min_valid_minutes: float = 20, random_state: int = 42
     ):

-        valid_minutes *= 60
+        min_valid_minutes *= 60
         valid_sec_now = 0.0
         valid_indices = []
         all_speakers = np.unique(
@@ -129,7 +131,7 @@ class TaskDataset(pl.LightningDataModule):
         possible_indices = list(range(0, len(all_speakers)))
         rng = create_unique_rng(len(all_speakers))

-        while valid_sec_now <= valid_minutes:
+        while valid_sec_now <= min_valid_minutes:
             speaker_index = rng.choice(possible_indices)
             possible_indices.remove(speaker_index)
             speaker_name = all_speakers[speaker_index]
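Note: the renamed train_valid_split keeps drawing whole speakers into the validation set until at least min_valid_minutes of audio has accumulated. A minimal standalone sketch of that selection loop, assuming a hypothetical durations mapping from speaker name to total seconds of audio:

import random

def pick_validation_speakers(durations, min_valid_minutes, seed=42):
    # durations: hypothetical {speaker_name: total_seconds} mapping
    rng = random.Random(seed)
    remaining = list(durations)
    picked, collected_sec = [], 0.0
    # remove randomly chosen speakers until the time budget is met
    while collected_sec <= min_valid_minutes * 60 and remaining:
        speaker = rng.choice(remaining)
        remaining.remove(speaker)
        picked.append(speaker)
        collected_sec += durations[speaker]
    return picked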
@@ -257,10 +259,15 @@ class EnhancerDataset(TaskDataset):
     files : Files
         dataclass containing train_clean, train_noisy, test_clean, test_noisy
         folder names (refer to the enhancer.utils.Files dataclass)
+    min_valid_minutes : float
+        minimum size of the validation split, in minutes
+        speakers are randomly drawn from the train data until their combined duration reaches min_valid_minutes; they form the validation data
     duration : float
         expected audio duration of a single audio sample for training
     sampling_rate : int
         desired sampling rate
+    padding_mode : str
+        padding mode ("silent" or "reflect")
     batch_size : int
         batch size of each batch
     num_workers : int
@@ -271,6 +278,7 @@ class EnhancerDataset(TaskDataset):
         use one_to_many mapping for multiple noisy files for each clean file

+
     """

     def __init__(
@@ -278,10 +286,11 @@ class EnhancerDataset(TaskDataset):
         name: str,
         root_dir: str,
         files: Files,
-        valid_minutes=5.0,
+        min_valid_minutes=5.0,
         duration=1.0,
         stride=None,
         sampling_rate=48000,
+        padding_mode: str = "silent",
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
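For callers, only the keyword name changes, plus the new optional padding_mode. A hypothetical instantiation after this commit (folder names are placeholders; the Files fields are taken from the docstring above, assuming Files is importable from enhancer.utils):

from enhancer.utils import Files

files = Files(
    train_clean="clean_trainset",
    train_noisy="noisy_trainset",
    test_clean="clean_testset",
    test_noisy="noisy_testset",
)
dataset = EnhancerDataset(
    name="vctk",
    root_dir="/data/vctk",
    files=files,
    min_valid_minutes=5.0,   # was valid_minutes before this commit
    padding_mode="reflect",  # new parameter; defaults to "silent"
)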
@@ -292,7 +301,7 @@ class EnhancerDataset(TaskDataset):
             name=name,
             root_dir=root_dir,
             files=files,
-            valid_minutes=valid_minutes,
+            min_valid_minutes=min_valid_minutes,
             sampling_rate=sampling_rate,
             duration=duration,
             matching_function=matching_function,
@@ -306,6 +315,7 @@ class EnhancerDataset(TaskDataset):
         self.duration = max(1.0, duration)
         self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
         self.stride = stride or duration
+        self.padding_mode = padding_mode

     def setup(self, stage: Optional[str] = None):

@@ -344,6 +354,7 @@ class EnhancerDataset(TaskDataset):
                     self.duration * self.sampling_rate - clean_segment.shape[-1]
                 ),
             ),
+            mode=self.padding_mode,
         )
         noisy_segment = F.pad(
             noisy_segment,
@@ -353,6 +364,7 @@ class EnhancerDataset(TaskDataset):
                     self.duration * self.sampling_rate - noisy_segment.shape[-1]
                 ),
             ),
+            mode=self.padding_mode,
         )
         return {
             "clean": clean_segment,
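One caveat worth noting: torch.nn.functional.pad itself only accepts modes such as "constant", "reflect", "replicate" and "circular", so a "silent" setting presumably has to be translated to zero-valued "constant" padding somewhere outside this diff. A sketch of right-padding a segment to a fixed length under that assumption:

import torch
import torch.nn.functional as F

def pad_segment(segment: torch.Tensor, num_samples: int, padding_mode: str) -> torch.Tensor:
    # hypothetical helper: right-pad the last dim of a (channels, samples)
    # tensor up to num_samples
    pad = max(0, num_samples - segment.shape[-1])
    if padding_mode == "silent":
        # silence = zero-valued constant padding
        return F.pad(segment, (0, pad), mode="constant", value=0.0)
    # e.g. "reflect"; note reflect requires pad < segment.shape[-1]
    return F.pad(segment, (0, pad), mode=padding_mode)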