stride waveform
This commit is contained in:
parent
415ed8e3d0
commit
edb7f020f7
|
|
@ -132,11 +132,14 @@ class TaskDataset(pl.LightningDataModule):
|
|||
for item in data:
|
||||
clean, noisy, total_dur = item.values()
|
||||
if total_dur < self.duration:
|
||||
continue
|
||||
num_segments = round(total_dur / self.duration)
|
||||
for index in range(num_segments):
|
||||
start_time = index * self.duration
|
||||
metadata.append(({"clean": clean, "noisy": noisy}, start_time))
|
||||
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
|
||||
else:
|
||||
num_segments = round(total_dur / self.duration)
|
||||
for index in range(num_segments):
|
||||
start_time = index * self.duration
|
||||
metadata.append(
|
||||
({"clean": clean, "noisy": noisy}, start_time)
|
||||
)
|
||||
return metadata
|
||||
|
||||
def train_dataloader(self):
|
||||
|
|
@ -195,6 +198,7 @@ class EnhancerDataset(TaskDataset):
|
|||
files: Files,
|
||||
valid_minutes=5.0,
|
||||
duration=1.0,
|
||||
stride=0.5,
|
||||
sampling_rate=48000,
|
||||
matching_function=None,
|
||||
batch_size=32,
|
||||
|
|
@ -217,6 +221,7 @@ class EnhancerDataset(TaskDataset):
|
|||
self.files = files
|
||||
self.duration = max(1.0, duration)
|
||||
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
|
||||
self.stride = stride or duration
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
|
||||
|
|
@ -234,9 +239,22 @@ class EnhancerDataset(TaskDataset):
|
|||
weights=[file["duration"] for file in self.train_data],
|
||||
)
|
||||
file_duration = file_dict["duration"]
|
||||
start_time = round(rng.uniform(0, file_duration - self.duration), 2)
|
||||
data = self.prepare_segment(file_dict, start_time)
|
||||
yield data
|
||||
num_segments = self.get_num_segments(
|
||||
file_duration, self.duration, self.stride
|
||||
)
|
||||
for index in range(0, num_segments):
|
||||
start_time = index * self.stride
|
||||
yield self.prepare_segment(file_dict, start_time)
|
||||
|
||||
@staticmethod
def get_num_segments(file_duration, duration, stride):
    """Return how many strided windows are needed for a single file.

    Windows are ``duration`` long and start every ``stride`` from 0.

    Args:
        file_duration: total length of the file.
        duration: length of one window, in the same unit as
            ``file_duration``.
        stride: hop between consecutive window starts; must be positive
            whenever the file is at least one window long.

    Returns:
        int: ``1`` when the file is shorter than one window, otherwise
        ``ceil((file_duration - duration) / stride) + 1``.

    Raises:
        ValueError: if ``stride`` is not positive and the file is at least
            one window long.
    """
    # A file shorter than one window still contributes exactly one segment.
    if file_duration < duration:
        return 1

    # A zero stride would divide by zero and a negative one would produce a
    # non-positive count, so reject both explicitly.
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride!r}")

    # Round up so a trailing partial hop still gets its own window.
    return math.ceil((file_duration - duration) / stride) + 1
|
||||
|
||||
def val__getitem__(self, idx):
    """Return the prepared validation segment stored at position ``idx``.

    The entry ``self._validation[idx]`` is unpacked into the positional
    arguments of ``self.prepare_segment``.
    """
    return self.prepare_segment(*self._validation[idx])
|
||||
|
|
@ -273,8 +291,16 @@ class EnhancerDataset(TaskDataset):
|
|||
return {"clean": clean_segment, "noisy": noisy_segment}
|
||||
|
||||
def train__len__(self):
    """Return the total number of training segments across ``train_data``.

    For every file, asks ``get_num_segments`` how many windows of
    ``self.duration`` fit at hop ``self.stride``, and sums the counts.
    """
    total_segments = sum(
        self.get_num_segments(file["duration"], self.duration, self.stride)
        for file in self.train_data
    )
    # The surrounding ``math.ceil`` from the original expression is kept so
    # the result is an integer whatever ``get_num_segments`` returns.
    return math.ceil(total_segments)
|
||||
|
||||
def val__len__(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue