default stride None

This commit is contained in:
shahules786 2022-10-24 21:15:25 +05:30
parent 75ebef2462
commit 5dc5fd8f90
1 changed files with 17 additions and 4 deletions

View File

@ -14,6 +14,9 @@ from enhancer.utils.config import Files
from enhancer.utils.io import Audio from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng from enhancer.utils.random import create_unique_rng
# from torch_audiomentations import Compose
LARGE_NUM = 2147483647 LARGE_NUM = 2147483647
@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
matching_function=None, matching_function=None,
batch_size=32, batch_size=32,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
# augmentations: Optional[Compose] = None,
): ):
super().__init__() super().__init__()
@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
else: else:
raise ValueError("valid_minutes must be greater than 0") raise ValueError("valid_minutes must be greater than 0")
# self.augmentations = augmentations
def setup(self, stage: Optional[str] = None): def setup(self, stage: Optional[str] = None):
""" """
prepare train/validation/test data splits prepare train/validation/test data splits
@ -161,7 +167,9 @@ class TaskDataset(pl.LightningDataModule):
if total_dur < self.duration: if total_dur < self.duration:
metadata.append(({"clean": clean, "noisy": noisy}, 0.0)) metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
else: else:
num_segments = round(total_dur / self.duration) num_segments = self.get_num_segments(
total_dur, self.duration, self.duration
)
for index in range(num_segments): for index in range(num_segments):
start_time = index * self.duration start_time = index * self.duration
metadata.append( metadata.append(
@ -175,8 +183,11 @@ class TaskDataset(pl.LightningDataModule):
@property @property
def generator(self): def generator(self):
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM) if hasattr(self, "model"):
return generator seed = self.model.current_epoch + LARGE_NUM
else:
seed = LARGE_NUM
return generator.manual_seed(seed)
def train_dataloader(self): def train_dataloader(self):
return DataLoader( return DataLoader(
@ -235,11 +246,12 @@ class EnhancerDataset(TaskDataset):
files: Files, files: Files,
valid_minutes=5.0, valid_minutes=5.0,
duration=1.0, duration=1.0,
stride=0.5, stride=None,
sampling_rate=48000, sampling_rate=48000,
matching_function=None, matching_function=None,
batch_size=32, batch_size=32,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
# augmentations: Optional[Compose] = None,
): ):
super().__init__( super().__init__(
@ -252,6 +264,7 @@ class EnhancerDataset(TaskDataset):
matching_function=matching_function, matching_function=matching_function,
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
# augmentations=augmentations,
) )
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate