vctk dataset

This commit is contained in:
shahules786 2022-08-22 13:25:43 +05:30
parent bcbc82dbad
commit 54a4364fb9
1 changed files with 46 additions and 11 deletions

View File

@ -1,15 +1,18 @@
from genericpath import isdir import glob
import librosa import math
import os import os
from scipy.io import wavfile
from torch.utils.data import IterableDataset from torch.utils.data import IterableDataset
import torch
from enhancer.utils.random import create_unique_rng
from enhancer.utils.io import Audio
class Vctk(IterableDataset): class Vctk(IterableDataset):
"""Dataset object for Voice Bank Corpus (VCTK) Dataset""" """Dataset object for Voice Bank Corpus (VCTK) Dataset"""
def __init__(self,clean_path,noisy_path,sample_length=1,num_samples=None): def __init__(self,clean_path,noisy_path,duration=1,sampling_rate=16000,num_samples=None):
if not os.path.isdir(clean_path): if not os.path.isdir(clean_path):
raise ValueError(f"{clean_path} is not a valid directory") raise ValueError(f"{clean_path} is not a valid directory")
@ -17,22 +20,54 @@ class Vctk(IterableDataset):
if not os.path.isdir(noisy_path): if not os.path.isdir(noisy_path):
raise ValueError(f"{clean_path} is not a valid directory") raise ValueError(f"{clean_path} is not a valid directory")
self.sampling_rate = sampling_rate
self.clean_path = clean_path self.clean_path = clean_path
self.noisy_path = noisy_path self.noisy_path = noisy_path
self.wav_samples =[file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
if num_samples is None: if num_samples is None:
self.num_samples = len([file for file in os.listdir(clean_path) if file.endswith(".wav")]) self.num_samples = len(self.wav_samples)
else: else:
self.num_samples = num_samples self.num_samples = num_samples
self.sample_length = max(0.1,sample_length) self.duration = max(1.0,duration)
self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True)
self.files_duration = self.get_files_duration()
def get_file_duration(self):
files_duration = {}
for file in self.clean_path:
wavfile = wavfile.read(os.path.join(self.clean_path,file),rate=self.sampling_rate)
files_duration.update({file:math.ceil(wavfile/self.sampling_rate)})
return files_duration
def __iter__(self): def __iter__(self):
rng = create_unique_rng(12) ##pass epoch number here
while True:
file_name = rng.choices(self.wav_samples,k=1)
file_duration = self.files_duration.get(file_name)
start_time = rng.randint(0,math.ceil(file_duration- self.duration))
data = self.prepare_segment(file_name,start_time)
yield data
pass def prepare_segment(self,file_name:str, start_time:int):
clean_segment = self.audio(os.path.join(self.clean_path,file_name),
offset=start_time,duration=self.duration)
noisy_segment = self.audio(os.path.join(self.noisy_path,file_name),
offset=start_time,duration=self.duration)
return {"clean": clean_segment,"noisy":noisy_segment}
def __len__(self): def __len__(self):
pass
return math.ceil(sum(self.files_duration.values())/self.duration)