rmv padding_mode

This commit is contained in:
shahules786 2022-10-29 10:39:32 +05:30
parent 6f1acf0423
commit 7f3dcf39c5
1 changed file with 6 additions and 9 deletions

View File

@@ -8,7 +8,7 @@ import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch_audiomentations import Compose

 from enhancer.data.fileprocessor import Fileprocessor
@@ -135,6 +135,7 @@ class TaskDataset(pl.LightningDataModule):
             speaker_index = rng.choice(possible_indices)
             possible_indices.remove(speaker_index)
             speaker_name = all_speakers[speaker_index]
+            print(f"Selected f{speaker_name} for valid")
             file_indices = [
                 i
                 for i, file in enumerate(data)
@@ -222,11 +223,13 @@ class TaskDataset(pl.LightningDataModule):
         return generator.manual_seed(seed)

     def train_dataloader(self):
+        dataset = TrainDataset(self)
+        sampler = RandomSampler(dataset, generator=self.generator)
         return DataLoader(
-            TrainDataset(self),
+            dataset,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            generator=self.generator,
+            sampler=sampler,
             collate_fn=self.train_collatefn,
         )
@@ -263,8 +266,6 @@ class EnhancerDataset(TaskDataset):
         expected audio duration of single audio sample for training
     sampling_rate : int
         desired sampling rate
-    padding_mode: str
-        padding mode (silent,reflect)
     batch_size : int
         batch size of each batch
     num_workers : int
@@ -287,7 +288,6 @@ class EnhancerDataset(TaskDataset):
         duration=1.0,
         stride=None,
         sampling_rate=48000,
-        padding_mode: str = "silent",
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
@@ -312,7 +312,6 @@ class EnhancerDataset(TaskDataset):
         self.duration = max(1.0, duration)
         self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
         self.stride = stride or duration
-        self.padding_mode = padding_mode

     def setup(self, stage: Optional[str] = None):
@ -337,7 +336,6 @@ class EnhancerDataset(TaskDataset):
return self.prepare_segment(*self._test[idx]) return self.prepare_segment(*self._test[idx])
def prepare_segment(self, file_dict: dict, start_time: float): def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio( clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration file_dict["clean"], offset=start_time, duration=self.duration
) )
@@ -362,7 +360,6 @@ class EnhancerDataset(TaskDataset):
                     self.duration * self.sampling_rate - noisy_segment.shape[-1]
                 ),
             ),
-            mode=self.padding_mode,
         )
         return {
             "clean": clean_segment,