refactor data modules

This commit is contained in:
shahules786 2022-10-05 15:05:40 +05:30
parent 451058c29d
commit b92310c93d
3 changed files with 169 additions and 123 deletions

View File

@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset

View File

@ -12,9 +12,9 @@ from enhancer.utils.io import Audio
from enhancer.utils import check_files
from enhancer.utils.config import Files
class TrainDataset(IterableDataset):
def __init__(self,dataset):
def __init__(self, dataset):
self.dataset = dataset
def __iter__(self):
@ -23,88 +23,102 @@ class TrainDataset(IterableDataset):
def __len__(self):
return self.dataset.train__len__()
class ValidDataset(Dataset):
def __init__(self,dataset):
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self,idx):
def __getitem__(self, idx):
return self.dataset.val__getitem__(idx)
def __len__(self):
return self.dataset.val__len__()
class TaskDataset(pl.LightningDataModule):
class TaskDataset(pl.LightningDataModule):
def __init__(
self,
name:str,
root_dir:str,
files:Files,
duration:float=1.0,
sampling_rate:int=48000,
matching_function = None,
name: str,
root_dir: str,
files: Files,
duration: float = 1.0,
sampling_rate: int = 48000,
matching_function=None,
batch_size=32,
num_workers:Optional[int]=None):
num_workers: Optional[int] = None,
):
super().__init__()
self.name = name
self.files,self.root_dir = check_files(root_dir,files)
self.files, self.root_dir = check_files(root_dir, files)
self.duration = duration
self.sampling_rate = sampling_rate
self.batch_size = batch_size
self.matching_function = matching_function
self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count()//2
num_workers = multiprocessing.cpu_count() // 2
self.num_workers = num_workers
def setup(self, stage: Optional[str] = None):
if stage in ("fit",None):
if stage in ("fit", None):
train_clean = os.path.join(self.root_dir,self.files.train_clean)
train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
fp = Fileprocessor.from_name(self.name,train_clean,
train_noisy, self.matching_function)
train_clean = os.path.join(self.root_dir, self.files.train_clean)
train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function
)
self.train_data = fp.prepare_matching_dict()
val_clean = os.path.join(self.root_dir,self.files.test_clean)
val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
fp = Fileprocessor.from_name(self.name,val_clean,
val_noisy, self.matching_function)
val_clean = os.path.join(self.root_dir, self.files.test_clean)
val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(
self.name, val_clean, val_noisy, self.matching_function
)
val_data = fp.prepare_matching_dict()
for item in val_data:
clean,noisy,total_dur = item.values()
clean, noisy, total_dur = item.values()
if total_dur < self.duration:
continue
num_segments = round(total_dur/self.duration)
num_segments = round(total_dur / self.duration)
for index in range(num_segments):
start_time = index * self.duration
self._validation.append(({"clean":clean,"noisy":noisy},
start_time))
self._validation.append(
({"clean": clean, "noisy": noisy}, start_time)
)
def train_dataloader(self):
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
return DataLoader(
TrainDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
def val_dataloader(self):
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
return DataLoader(
ValidDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
class EnhancerDataset(TaskDataset):
"""
Dataset object for creating clean-noisy speech enhancement datasets
paramters:
name : str
name of the dataset
name of the dataset
root_dir : str
root directory of the dataset containing clean/noisy folders
files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer cli/train_config/dataset)
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer enhancer.utils.Files dataclass)
duration : float
expected audio duration of single audio sample for training
sampling_rate : int
desired sampling rate
desired sampling rate
batch_size : int
batch size of each batch
num_workers : int
@ -114,71 +128,92 @@ class EnhancerDataset(TaskDataset):
use one_to_one mapping for datasets with one noisy file for each clean file
use one_to_many mapping for multiple noisy files for each clean file
"""
def __init__(
self,
name:str,
root_dir:str,
files:Files,
name: str,
root_dir: str,
files: Files,
duration=1.0,
sampling_rate=48000,
matching_function=None,
batch_size=32,
num_workers:Optional[int]=None):
num_workers: Optional[int] = None,
):
super().__init__(
name=name,
root_dir=root_dir,
files=files,
sampling_rate=sampling_rate,
duration=duration,
matching_function = matching_function,
matching_function=matching_function,
batch_size=batch_size,
num_workers = num_workers,
num_workers=num_workers,
)
self.sampling_rate = sampling_rate
self.files = files
self.duration = max(1.0,duration)
self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True)
self.duration = max(1.0, duration)
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
def setup(self, stage: Optional[str] = None):
def setup(self, stage:Optional[str]=None):
super().setup(stage=stage)
def train__iter__(self):
rng = create_unique_rng(self.model.current_epoch)
rng = create_unique_rng(self.model.current_epoch)
while True:
file_dict,*_ = rng.choices(self.train_data,k=1,
weights=[file["duration"] for file in self.train_data])
file_duration = file_dict['duration']
start_time = round(rng.uniform(0,file_duration- self.duration),2)
data = self.prepare_segment(file_dict,start_time)
file_dict, *_ = rng.choices(
self.train_data,
k=1,
weights=[file["duration"] for file in self.train_data],
)
file_duration = file_dict["duration"]
start_time = round(rng.uniform(0, file_duration - self.duration), 2)
data = self.prepare_segment(file_dict, start_time)
yield data
def val__getitem__(self,idx):
def val__getitem__(self, idx):
return self.prepare_segment(*self._validation[idx])
def prepare_segment(self,file_dict:dict, start_time:float):
clean_segment = self.audio(file_dict["clean"],
offset=start_time,duration=self.duration)
noisy_segment = self.audio(file_dict["noisy"],
offset=start_time,duration=self.duration)
clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
return {"clean": clean_segment,"noisy":noisy_segment}
def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration
)
noisy_segment = self.audio(
file_dict["noisy"], offset=start_time, duration=self.duration
)
clean_segment = F.pad(
clean_segment,
(
0,
int(
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
)
noisy_segment = F.pad(
noisy_segment,
(
0,
int(
self.duration * self.sampling_rate - noisy_segment.shape[-1]
),
),
)
return {"clean": clean_segment, "noisy": noisy_segment}
def train__len__(self):
return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
return math.ceil(
sum([file["duration"] for file in self.train_data]) / self.duration
)
def val__len__(self):
return len(self._validation)

View File

@ -4,105 +4,115 @@ from re import S
import numpy as np
from scipy.io import wavfile
MATCHING_FNS = ("one_to_one","one_to_many")
MATCHING_FNS = ("one_to_one", "one_to_many")
class ProcessorFunctions:
"""
Preprocessing methods for different types of speech enhacement datasets.
"""
@staticmethod
def one_to_one(clean_path,noisy_path):
def one_to_one(clean_path, noisy_path):
"""
One clean audio can have only one noisy audio file
"""
matching_wavfiles = list()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))]
common_filenames = np.intersect1d(noisy_filenames,clean_filenames)
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
noisy_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(noisy_path, "*.wav"))
]
common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
for file_name in common_filenames:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr_noisy)):
sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, file_name)
)
sr_noisy, noisy_file = wavfile.read(
os.path.join(noisy_path, file_name)
)
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.append(
{"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
"duration":clean_file.shape[-1]/sr_clean}
)
{
"clean": os.path.join(clean_path, file_name),
"noisy": os.path.join(noisy_path, file_name),
"duration": clean_file.shape[-1] / sr_clean,
}
)
return matching_wavfiles
@staticmethod
def one_to_many(clean_path,noisy_path):
def one_to_many(clean_path, noisy_path):
"""
One clean audio have multiple noisy audio files
"""
matching_wavfiles = dict()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
for clean_file in clean_filenames:
noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav"))
noisy_filenames = glob.glob(
os.path.join(noisy_path, f"*_{clean_file}.wav")
)
for noisy_file in noisy_filenames:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, clean_file)
)
sr_noisy, noisy_file = wavfile.read(noisy_file)
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr_noisy)):
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.update(
{"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
"duration":clean_file.shape[-1]/sr_clean}
)
{
"clean": os.path.join(clean_path, clean_file),
"noisy": noisy_file,
"duration": clean_file.shape[-1] / sr_clean,
}
)
return matching_wavfiles
class Fileprocessor:
def __init__(
self,
clean_dir,
noisy_dir,
matching_function = None
):
def __init__(self, clean_dir, noisy_dir, matching_function=None):
self.clean_dir = clean_dir
self.noisy_dir = noisy_dir
self.matching_function = matching_function
@classmethod
def from_name(cls,
name:str,
clean_dir,
noisy_dir,
matching_function=None
):
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
if matching_function is None:
if name.lower() == "vctk":
return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
elif name.lower() == "dns-2020":
return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
else:
if matching_function not in MATCHING_FNS:
raise ValueError(F"Invalid matching function! Avaialble options are {MATCHING_FNS}")
raise ValueError(
f"Invalid matching function! Avaialble options are {MATCHING_FNS}"
)
else:
return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
return cls(
clean_dir,
noisy_dir,
getattr(ProcessorFunctions, matching_function),
)
def prepare_matching_dict(self):
if self.matching_function is None:
raise ValueError("Not a valid matching function")
return self.matching_function(self.clean_dir,self.noisy_dir)
return self.matching_function(self.clean_dir, self.noisy_dir)