add random sampler
This commit is contained in:
parent
fb2543e81e
commit
aa52d1ed93
|
|
@ -8,7 +8,7 @@ import numpy as np
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||
from torch_audiomentations import Compose
|
||||
|
||||
from enhancer.data.fileprocessor import Fileprocessor
|
||||
|
|
@ -133,6 +133,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
speaker_index = rng.choice(possible_indices)
|
||||
possible_indices.remove(speaker_index)
|
||||
speaker_name = all_speakers[speaker_index]
|
||||
print(f"Selected f{speaker_name} for valid")
|
||||
file_indices = [
|
||||
i
|
||||
for i, file in enumerate(data)
|
||||
|
|
@ -220,11 +221,13 @@ class TaskDataset(pl.LightningDataModule):
|
|||
return generator.manual_seed(seed)
|
||||
|
||||
def train_dataloader(self):
|
||||
dataset = TrainDataset(self)
|
||||
sampler = RandomSampler(dataset, generator=self.generator)
|
||||
return DataLoader(
|
||||
TrainDataset(self),
|
||||
dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
generator=self.generator,
|
||||
sampler=sampler,
|
||||
collate_fn=self.train_collatefn,
|
||||
)
|
||||
|
||||
|
|
@ -327,7 +330,6 @@ class EnhancerDataset(TaskDataset):
|
|||
return self.prepare_segment(*self._test[idx])
|
||||
|
||||
def prepare_segment(self, file_dict: dict, start_time: float):
|
||||
|
||||
clean_segment = self.audio(
|
||||
file_dict["clean"], offset=start_time, duration=self.duration
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue