split validation criterion

This commit is contained in:
shahules786 2022-10-27 15:19:02 +05:30
parent 47bbee2c32
commit e1963ff001
1 changed file with 24 additions and 9 deletions

View File

@ -1,8 +1,10 @@
import math
import multiprocessing
import os
from pathlib import Path
from typing import Optional
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
@ -119,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
):
valid_minutes *= 60
valid_min_now = 0.0
valid_sec_now = 0.0
valid_indices = []
random_indices = list(range(0, len(data)))
rng = create_unique_rng(random_state)
rng.shuffle(random_indices)
i = 0
while valid_min_now <= valid_minutes:
valid_indices.append(random_indices[i])
valid_min_now += data[random_indices[i]]["duration"]
i += 1
all_speakers = np.unique(
[
(Path(file["clean"]).name.split("_")[0], file["duration"])
for file in data
]
)
possible_indices = list(range(0, len(all_speakers)))
rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
file_indices = [
i
for i, file in enumerate(data)
if speaker_name == Path(file["clean"]).name.split("_")[0]
]
for i in file_indices:
valid_indices.append(i)
valid_sec_now += data[i]["duration"]
train_data = [
item for i, item in enumerate(data) if i not in valid_indices