From e1963ff001fe7785ca809a0df48799077938c8fc Mon Sep 17 00:00:00 2001 From: shahules786 Date: Thu, 27 Oct 2022 15:19:02 +0530 Subject: [PATCH] split validation criterion --- enhancer/data/dataset.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index f05fd6b..dac2c50 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -1,8 +1,10 @@ import math import multiprocessing import os +from pathlib import Path from typing import Optional +import numpy as np import pytorch_lightning as pl import torch import torch.nn.functional as F @@ -119,16 +121,29 @@ class TaskDataset(pl.LightningDataModule): ): valid_minutes *= 60 - valid_min_now = 0.0 + valid_sec_now = 0.0 valid_indices = [] - random_indices = list(range(0, len(data))) - rng = create_unique_rng(random_state) - rng.shuffle(random_indices) - i = 0 - while valid_min_now <= valid_minutes: - valid_indices.append(random_indices[i]) - valid_min_now += data[random_indices[i]]["duration"] - i += 1 + all_speakers = np.unique( + [ + (Path(file["clean"]).name.split("_")[0], file["duration"]) + for file in data + ] + ) + possible_indices = list(range(0, len(all_speakers))) + rng = create_unique_rng(len(all_speakers)) + + while valid_sec_now <= valid_minutes: + speaker_index = rng.choice(possible_indices) + possible_indices.remove(speaker_index) + speaker_name = all_speakers[speaker_index] + file_indices = [ + i + for i, file in enumerate(data) + if speaker_name == Path(file["clean"]).name.split("_")[0] + ] + for i in file_indices: + valid_indices.append(i) + valid_sec_now += data[i]["duration"] train_data = [ item for i, item in enumerate(data) if i not in valid_indices