split validation criterion
This commit is contained in:
parent
47bbee2c32
commit
e1963ff001
|
|
@ -1,8 +1,10 @@
|
|||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -119,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
|
|||
):
|
||||
|
||||
valid_minutes *= 60
|
||||
valid_min_now = 0.0
|
||||
valid_sec_now = 0.0
|
||||
valid_indices = []
|
||||
random_indices = list(range(0, len(data)))
|
||||
rng = create_unique_rng(random_state)
|
||||
rng.shuffle(random_indices)
|
||||
i = 0
|
||||
while valid_min_now <= valid_minutes:
|
||||
valid_indices.append(random_indices[i])
|
||||
valid_min_now += data[random_indices[i]]["duration"]
|
||||
i += 1
|
||||
all_speakers = np.unique(
|
||||
[
|
||||
(Path(file["clean"]).name.split("_")[0], file["duration"])
|
||||
for file in data
|
||||
]
|
||||
)
|
||||
possible_indices = list(range(0, len(all_speakers)))
|
||||
rng = create_unique_rng(len(all_speakers))
|
||||
|
||||
while valid_sec_now <= valid_minutes:
|
||||
speaker_index = rng.choice(possible_indices)
|
||||
possible_indices.remove(speaker_index)
|
||||
speaker_name = all_speakers[speaker_index]
|
||||
file_indices = [
|
||||
i
|
||||
for i, file in enumerate(data)
|
||||
if speaker_name == Path(file["clean"]).name.split("_")[0]
|
||||
]
|
||||
for i in file_indices:
|
||||
valid_indices.append(i)
|
||||
valid_sec_now += data[i]["duration"]
|
||||
|
||||
train_data = [
|
||||
item for i, item in enumerate(data) if i not in valid_indices
|
||||
|
|
|
|||
Loading…
Reference in New Issue