default stride None
parent 75ebef2462
commit 5dc5fd8f90
@@ -14,6 +14,9 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng

+# from torch_audiomentations import Compose
+
+
 LARGE_NUM = 2147483647


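Note: the commented-out Compose import above only scaffolds an augmentation hook; nothing in this commit actually uses it. If it were enabled, a torch_audiomentations pipeline would typically be built and applied along these lines (a hedged sketch; the transform choice, tensor shape, and variable names are illustrative, not taken from this repository):

import torch
from torch_audiomentations import Compose, Gain

# Illustrative pipeline: a single random-gain transform applied with probability 0.5.
augment = Compose(transforms=[Gain(min_gain_in_db=-6.0, max_gain_in_db=6.0, p=0.5)])
waveform = torch.randn(1, 1, 48000)  # (batch, channels, samples) at 48 kHz
augmented = augment(waveform, sample_rate=48000)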
@@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        # augmentations: Optional[Compose] = None,
     ):
         super().__init__()

@@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
         else:
             raise ValueError("valid_minutes must be greater than 0")

+        # self.augmentations = augmentations
+
     def setup(self, stage: Optional[str] = None):
         """
         prepare train/validation/test data splits
@@ -161,7 +167,9 @@ class TaskDataset(pl.LightningDataModule):
             if total_dur < self.duration:
                 metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
             else:
-                num_segments = round(total_dur / self.duration)
+                num_segments = self.get_num_segments(
+                    total_dur, self.duration, self.duration
+                )
                 for index in range(num_segments):
                     start_time = index * self.duration
                     metadata.append(
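Note: get_num_segments is called here with duration passed as both the window length and the stride, but its definition is not part of this diff. A minimal sketch of what such a helper plausibly computes, assuming it counts how many windows of `duration` seconds fit into `total_dur` when the window start advances by `stride` seconds (the function below is hypothetical, not the repository's implementation):

def get_num_segments(total_dur: float, duration: float, stride: float) -> int:
    # How many stride-spaced windows of `duration` seconds fit in `total_dur`.
    if total_dur < duration:
        return 1
    return int((total_dur - duration) // stride) + 1

With stride equal to duration, as in the call above, this yields non-overlapping segments, which is consistent with start_time = index * self.duration in the loop that follows.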
@@ -175,8 +183,11 @@ class TaskDataset(pl.LightningDataModule):
     @property
     def generator(self):
         generator = torch.Generator()
-        generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
-        return generator
+        if hasattr(self, "model"):
+            seed = self.model.current_epoch + LARGE_NUM
+        else:
+            seed = LARGE_NUM
+        return generator.manual_seed(seed)

     def train_dataloader(self):
         return DataLoader(
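Note: the reworked generator property avoids an AttributeError when self.model has not been attached yet (for example, when the DataModule is used before the Lightning Trainer links it to a model), falling back to the fixed LARGE_NUM seed. The same pattern in isolation (function and argument names are illustrative):

import torch

LARGE_NUM = 2147483647

def make_generator(model=None) -> torch.Generator:
    # Seed from the current epoch when a model is available; otherwise use
    # the fixed constant, mirroring the hasattr guard in the hunk above.
    generator = torch.Generator()
    seed = model.current_epoch + LARGE_NUM if model is not None else LARGE_NUM
    return generator.manual_seed(seed)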
@@ -235,11 +246,12 @@ class EnhancerDataset(TaskDataset):
         files: Files,
         valid_minutes=5.0,
         duration=1.0,
-        stride=0.5,
+        stride=None,
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        # augmentations: Optional[Compose] = None,
     ):

         super().__init__(
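Note: this hunk carries the commit's headline change, moving the EnhancerDataset stride default from 0.5 to None. How a None stride is subsequently resolved is not shown in this diff; a common convention, sketched below under that assumption, is to fall back to `duration` so that training segments are non-overlapping by default:

from typing import Optional

class ExampleDataset:
    # Hypothetical sketch (class and attribute names illustrative): resolve a
    # None stride to the window duration, i.e. non-overlapping segments.
    def __init__(self, duration: float = 1.0, stride: Optional[float] = None):
        self.duration = duration
        self.stride = stride if stride is not None else duration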
@@ -252,6 +264,7 @@ class EnhancerDataset(TaskDataset):
             matching_function=matching_function,
             batch_size=batch_size,
             num_workers=num_workers,
+            # augmentations=augmentations,
         )

         self.sampling_rate = sampling_rate