default stride None
This commit is contained in:
parent
75ebef2462
commit
5dc5fd8f90
|
|
@ -14,6 +14,9 @@ from enhancer.utils.config import Files
|
||||||
from enhancer.utils.io import Audio
|
from enhancer.utils.io import Audio
|
||||||
from enhancer.utils.random import create_unique_rng
|
from enhancer.utils.random import create_unique_rng
|
||||||
|
|
||||||
|
# from torch_audiomentations import Compose
|
||||||
|
|
||||||
|
|
||||||
LARGE_NUM = 2147483647
|
LARGE_NUM = 2147483647
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
|
# augmentations: Optional[Compose] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
else:
|
else:
|
||||||
raise ValueError("valid_minutes must be greater than 0")
|
raise ValueError("valid_minutes must be greater than 0")
|
||||||
|
|
||||||
|
# self.augmentations = augmentations
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
prepare train/validation/test data splits
|
prepare train/validation/test data splits
|
||||||
|
|
@ -161,7 +167,9 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
if total_dur < self.duration:
|
if total_dur < self.duration:
|
||||||
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
|
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
|
||||||
else:
|
else:
|
||||||
num_segments = round(total_dur / self.duration)
|
num_segments = self.get_num_segments(
|
||||||
|
total_dur, self.duration, self.duration
|
||||||
|
)
|
||||||
for index in range(num_segments):
|
for index in range(num_segments):
|
||||||
start_time = index * self.duration
|
start_time = index * self.duration
|
||||||
metadata.append(
|
metadata.append(
|
||||||
|
|
@ -175,8 +183,11 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
@property
|
@property
|
||||||
def generator(self):
|
def generator(self):
|
||||||
generator = torch.Generator()
|
generator = torch.Generator()
|
||||||
generator = generator.manual_seed(self.model.current_epoch + LARGE_NUM)
|
if hasattr(self, "model"):
|
||||||
return generator
|
seed = self.model.current_epoch + LARGE_NUM
|
||||||
|
else:
|
||||||
|
seed = LARGE_NUM
|
||||||
|
return generator.manual_seed(seed)
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
|
|
@ -235,11 +246,12 @@ class EnhancerDataset(TaskDataset):
|
||||||
files: Files,
|
files: Files,
|
||||||
valid_minutes=5.0,
|
valid_minutes=5.0,
|
||||||
duration=1.0,
|
duration=1.0,
|
||||||
stride=0.5,
|
stride=None,
|
||||||
sampling_rate=48000,
|
sampling_rate=48000,
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
|
# augmentations: Optional[Compose] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -252,6 +264,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
matching_function=matching_function,
|
matching_function=matching_function,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
|
# augmentations=augmentations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue