Remove padding_mode from EnhancerDataset
This commit is contained in:
parent
6f1acf0423
commit
7f3dcf39c5
|
|
@ -8,7 +8,7 @@ import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||||
from torch_audiomentations import Compose
|
from torch_audiomentations import Compose
|
||||||
|
|
||||||
from enhancer.data.fileprocessor import Fileprocessor
|
from enhancer.data.fileprocessor import Fileprocessor
|
||||||
|
|
@ -135,6 +135,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
speaker_index = rng.choice(possible_indices)
|
speaker_index = rng.choice(possible_indices)
|
||||||
possible_indices.remove(speaker_index)
|
possible_indices.remove(speaker_index)
|
||||||
speaker_name = all_speakers[speaker_index]
|
speaker_name = all_speakers[speaker_index]
|
||||||
|
print(f"Selected {speaker_name} for valid")
|
||||||
file_indices = [
|
file_indices = [
|
||||||
i
|
i
|
||||||
for i, file in enumerate(data)
|
for i, file in enumerate(data)
|
||||||
|
|
@ -222,11 +223,13 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
return generator.manual_seed(seed)
|
return generator.manual_seed(seed)
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
|
dataset = TrainDataset(self)
|
||||||
|
sampler = RandomSampler(dataset, generator=self.generator)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
TrainDataset(self),
|
dataset,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
num_workers=self.num_workers,
|
num_workers=self.num_workers,
|
||||||
generator=self.generator,
|
sampler=sampler,
|
||||||
collate_fn=self.train_collatefn,
|
collate_fn=self.train_collatefn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -263,8 +266,6 @@ class EnhancerDataset(TaskDataset):
|
||||||
expected audio duration of single audio sample for training
|
expected audio duration of single audio sample for training
|
||||||
sampling_rate : int
|
sampling_rate : int
|
||||||
desired sampling rate
|
desired sampling rate
|
||||||
padding_mode: str
|
|
||||||
padding mode (silent,reflect)
|
|
||||||
batch_size : int
|
batch_size : int
|
||||||
batch size of each batch
|
batch size of each batch
|
||||||
num_workers : int
|
num_workers : int
|
||||||
|
|
@ -287,7 +288,6 @@ class EnhancerDataset(TaskDataset):
|
||||||
duration=1.0,
|
duration=1.0,
|
||||||
stride=None,
|
stride=None,
|
||||||
sampling_rate=48000,
|
sampling_rate=48000,
|
||||||
padding_mode: str = "silent",
|
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
|
|
@ -312,7 +312,6 @@ class EnhancerDataset(TaskDataset):
|
||||||
self.duration = max(1.0, duration)
|
self.duration = max(1.0, duration)
|
||||||
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
|
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
|
||||||
self.stride = stride or duration
|
self.stride = stride or duration
|
||||||
self.padding_mode = padding_mode
|
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
|
|
||||||
|
|
@ -337,7 +336,6 @@ class EnhancerDataset(TaskDataset):
|
||||||
return self.prepare_segment(*self._test[idx])
|
return self.prepare_segment(*self._test[idx])
|
||||||
|
|
||||||
def prepare_segment(self, file_dict: dict, start_time: float):
|
def prepare_segment(self, file_dict: dict, start_time: float):
|
||||||
|
|
||||||
clean_segment = self.audio(
|
clean_segment = self.audio(
|
||||||
file_dict["clean"], offset=start_time, duration=self.duration
|
file_dict["clean"], offset=start_time, duration=self.duration
|
||||||
)
|
)
|
||||||
|
|
@ -362,7 +360,6 @@ class EnhancerDataset(TaskDataset):
|
||||||
self.duration * self.sampling_rate - noisy_segment.shape[-1]
|
self.duration * self.sampling_rate - noisy_segment.shape[-1]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
mode=self.padding_mode,
|
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"clean": clean_segment,
|
"clean": clean_segment,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue