split validation criterion

This commit is contained in:
shahules786 2022-10-27 15:19:02 +05:30
parent 47bbee2c32
commit e1963ff001
1 changed file with 24 additions and 9 deletions

View File

@ -1,8 +1,10 @@
import math import math
import multiprocessing import multiprocessing
import os import os
from pathlib import Path
from typing import Optional from typing import Optional
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -119,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
): ):
valid_minutes *= 60 valid_minutes *= 60
valid_min_now = 0.0 valid_sec_now = 0.0
valid_indices = [] valid_indices = []
random_indices = list(range(0, len(data))) all_speakers = np.unique(
rng = create_unique_rng(random_state) [
rng.shuffle(random_indices) (Path(file["clean"]).name.split("_")[0], file["duration"])
i = 0 for file in data
while valid_min_now <= valid_minutes: ]
valid_indices.append(random_indices[i]) )
valid_min_now += data[random_indices[i]]["duration"] possible_indices = list(range(0, len(all_speakers)))
i += 1 rng = create_unique_rng(len(all_speakers))
while valid_sec_now <= valid_minutes:
speaker_index = rng.choice(possible_indices)
possible_indices.remove(speaker_index)
speaker_name = all_speakers[speaker_index]
file_indices = [
i
for i, file in enumerate(data)
if speaker_name == Path(file["clean"]).name.split("_")[0]
]
for i in file_indices:
valid_indices.append(i)
valid_sec_now += data[i]["duration"]
train_data = [ train_data = [
item for i, item in enumerate(data) if i not in valid_indices item for i, item in enumerate(data) if i not in valid_indices