fix sampling bugs

This commit is contained in:
shahules786 2022-08-23 13:33:46 +05:30
parent 54a4364fb9
commit 65540148f7
1 changed files with 43 additions and 20 deletions

View File

@ -1,18 +1,37 @@
import glob import glob
import math import math
import numpy as np
import os import os
from scipy.io import wavfile from scipy.io import wavfile
from torch.utils.data import IterableDataset from torch.utils.data import IterableDataset
import torch.nn.functional as F
from enhancer.utils.random import create_unique_rng from enhancer.utils.random import create_unique_rng
from enhancer.utils.io import Audio from enhancer.utils.io import Audio
class VctkDataset:
def __init__(self):
pass
def train_loader(self):
pass
def valid_loader(self):
pass
def test_loader(self):
pass
class Vctk(IterableDataset): class Vctk(IterableDataset):
"""Dataset object for Voice Bank Corpus (VCTK) Dataset""" """Dataset object for Voice Bank Corpus (VCTK) Dataset"""
def __init__(self,clean_path,noisy_path,duration=1,sampling_rate=16000,num_samples=None): def __init__(self,clean_path,noisy_path,duration=1.0,sampling_rate=48000):
if not os.path.isdir(clean_path): if not os.path.isdir(clean_path):
raise ValueError(f"{clean_path} is not a valid directory") raise ValueError(f"{clean_path} is not a valid directory")
@ -23,26 +42,28 @@ class Vctk(IterableDataset):
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate
self.clean_path = clean_path self.clean_path = clean_path
self.noisy_path = noisy_path self.noisy_path = noisy_path
self.wav_samples =[file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] self.files_duration = self.get_matching_files_duration()
self.wav_samples = list(self.files_duration.keys())
if num_samples is None:
self.num_samples = len(self.wav_samples)
else:
self.num_samples = num_samples
self.duration = max(1.0,duration) self.duration = max(1.0,duration)
self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True) self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True)
self.files_duration = self.get_files_duration()
def get_file_duration(self): def get_matching_files_duration(self):
files_duration = {} matching_wavfiles_dur = dict()
for file in self.clean_path: clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(self.clean_path,"*.wav"))]
wavfile = wavfile.read(os.path.join(self.clean_path,file),rate=self.sampling_rate) noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(self.noisy_path,"*.wav"))]
files_duration.update({file:math.ceil(wavfile/self.sampling_rate)}) common_filenames = np.intersect1d(noisy_filenames,clean_filenames)
return files_duration for file_name in common_filenames:
sr_clean, clean_file = wavfile.read(os.path.join(self.clean_path,file_name))
sr_noisy, noisy_file = wavfile.read(os.path.join(self.noisy_path,file_name))
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==self.sampling_rate) and
(sr_noisy==self.sampling_rate)):
matching_wavfiles_dur.update({file_name:(clean_file.shape[-1]/self.sampling_rate)})
return matching_wavfiles_dur
def __iter__(self): def __iter__(self):
@ -50,19 +71,21 @@ class Vctk(IterableDataset):
while True: while True:
file_name = rng.choices(self.wav_samples,k=1) file_name,*_ = rng.choices(self.wav_samples,k=1,
weights=[self.files_duration[file] for file in self.wav_samples])
file_duration = self.files_duration.get(file_name) file_duration = self.files_duration.get(file_name)
start_time = rng.randint(0,math.ceil(file_duration- self.duration)) start_time = round(rng.uniform(0,file_duration- self.duration),2)
data = self.prepare_segment(file_name,start_time) data = self.prepare_segment(file_name,start_time)
yield data yield data
def prepare_segment(self,file_name:str, start_time:int): def prepare_segment(self,file_name:str, start_time:float):
clean_segment = self.audio(os.path.join(self.clean_path,file_name), clean_segment = self.audio(os.path.join(self.clean_path,file_name),
offset=start_time,duration=self.duration) offset=start_time,duration=self.duration)
noisy_segment = self.audio(os.path.join(self.noisy_path,file_name), noisy_segment = self.audio(os.path.join(self.noisy_path,file_name),
offset=start_time,duration=self.duration) offset=start_time,duration=self.duration)
clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
return {"clean": clean_segment,"noisy":noisy_segment} return {"clean": clean_segment,"noisy":noisy_segment}
def __len__(self): def __len__(self):