default stride None

This commit is contained in:
shahules786 2022-10-24 21:15:25 +05:30
parent 75ebef2462
commit 5dc5fd8f90
1 changed files with 17 additions and 4 deletions

View File

@ -14,6 +14,9 @@ from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng
# from torch_audiomentations import Compose
LARGE_NUM = 2147483647
@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
# augmentations: Optional[Compose] = None,
):
super().__init__()
@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
else:
raise ValueError("valid_minutes must be greater than 0")
# self.augmentations = augmentations
def setup(self, stage: Optional[str] = None):
"""
prepare train/validation/test data splits
@ -161,7 +167,9 @@ class TaskDataset(pl.LightningDataModule):
if total_dur < self.duration:
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
else:
num_segments = round(total_dur / self.duration)
num_segments = self.get_num_segments(
total_dur, self.duration, self.duration
)
for index in range(num_segments):
start_time = index * self.duration
metadata.append(
@ -175,8 +183,11 @@ class TaskDataset(pl.LightningDataModule):
@property
def generator(self):
generator = torch.Generator()
generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
return generator
if hasattr(self, "model"):
seed = self.model.current_epoch + LARGE_NUM
else:
seed = LARGE_NUM
return generator.manual_seed(seed)
def train_dataloader(self):
return DataLoader(
@ -235,11 +246,12 @@ class EnhancerDataset(TaskDataset):
files: Files,
valid_minutes=5.0,
duration=1.0,
stride=0.5,
stride=None,
sampling_rate=48000,
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
# augmentations: Optional[Compose] = None,
):
super().__init__(
@ -252,6 +264,7 @@ class EnhancerDataset(TaskDataset):
matching_function=matching_function,
batch_size=batch_size,
num_workers=num_workers,
# augmentations=augmentations,
)
self.sampling_rate = sampling_rate