refactor data modules

shahules786 2022-10-05 15:05:40 +05:30
parent 451058c29d
commit b92310c93d
3 changed files with 169 additions and 123 deletions


@@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset


@@ -12,9 +12,9 @@ from enhancer.utils.io import Audio
from enhancer.utils import check_files
from enhancer.utils.config import Files


class TrainDataset(IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
@@ -23,88 +23,102 @@ class TrainDataset(IterableDataset):
    def __len__(self):
        return self.dataset.train__len__()


class ValidDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        return self.dataset.val__getitem__(idx)

    def __len__(self):
        return self.dataset.val__len__()


class TaskDataset(pl.LightningDataModule):
    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        duration: float = 1.0,
        sampling_rate: int = 48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
    ):
        super().__init__()
        self.name = name
        self.files, self.root_dir = check_files(root_dir, files)
        self.duration = duration
        self.sampling_rate = sampling_rate
        self.batch_size = batch_size
        self.matching_function = matching_function
        self._validation = []
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
        if stage in ("fit", None):
            train_clean = os.path.join(self.root_dir, self.files.train_clean)
            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
            fp = Fileprocessor.from_name(
                self.name, train_clean, train_noisy, self.matching_function
            )
            self.train_data = fp.prepare_matching_dict()

            val_clean = os.path.join(self.root_dir, self.files.test_clean)
            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
            fp = Fileprocessor.from_name(
                self.name, val_clean, val_noisy, self.matching_function
            )
            val_data = fp.prepare_matching_dict()
            for item in val_data:
                clean, noisy, total_dur = item.values()
                if total_dur < self.duration:
                    continue
                num_segments = round(total_dur / self.duration)
                for index in range(num_segments):
                    start_time = index * self.duration
                    self._validation.append(
                        ({"clean": clean, "noisy": noisy}, start_time)
                    )

    def train_dataloader(self):
        return DataLoader(
            TrainDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            ValidDataset(self),
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
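TrainDataset and ValidDataset stay deliberately thin: each forwards to the TaskDataset instance it wraps, so a single LightningDataModule object backs both loaders without duplicating state. Presumably any TaskDataset subclass only needs to supply the four forwarded hooks; a minimal sketch of that contract (MyTask is a hypothetical subclass):

class MyTask(TaskDataset):
    def train__iter__(self):  # consumed via TrainDataset.__iter__
        yield {"clean": ..., "noisy": ...}

    def train__len__(self):  # consumed via TrainDataset.__len__
        return 1

    def val__getitem__(self, idx):  # consumed via ValidDataset.__getitem__
        return {"clean": ..., "noisy": ...}

    def val__len__(self):  # consumed via ValidDataset.__len__
        return 1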
class EnhancerDataset(TaskDataset):
    """
    Dataset object for creating clean-noisy speech enhancement datasets

    parameters:

    name : str
        name of the dataset
    root_dir : str
        root directory of the dataset containing clean/noisy folders
    files : Files
        dataclass containing train_clean, train_noisy, test_clean, test_noisy
        folder names (refer enhancer.utils.Files dataclass)
    duration : float
        expected audio duration of single audio sample for training
    sampling_rate : int
        desired sampling rate
    batch_size : int
        batch size of each batch
    num_workers : int
@@ -114,71 +128,92 @@ class EnhancerDataset(TaskDataset):
        use one_to_one mapping for datasets with one noisy file for each clean file
        use one_to_many mapping for multiple noisy files for each clean file

    """
    def __init__(
        self,
        name: str,
        root_dir: str,
        files: Files,
        duration=1.0,
        sampling_rate=48000,
        matching_function=None,
        batch_size=32,
        num_workers: Optional[int] = None,
    ):

        super().__init__(
            name=name,
            root_dir=root_dir,
            files=files,
            sampling_rate=sampling_rate,
            duration=duration,
            matching_function=matching_function,
            batch_size=batch_size,
            num_workers=num_workers,
        )

        self.sampling_rate = sampling_rate
        self.files = files
        self.duration = max(1.0, duration)
        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)

    def setup(self, stage: Optional[str] = None):
        super().setup(stage=stage)
    def train__iter__(self):
        # fresh, epoch-seeded random stream so shuffling is reproducible
        rng = create_unique_rng(self.model.current_epoch)

        while True:
            # pick files proportionally to their duration, then a random
            # fixed-length window inside the chosen file
            file_dict, *_ = rng.choices(
                self.train_data,
                k=1,
                weights=[file["duration"] for file in self.train_data],
            )
            file_duration = file_dict["duration"]
            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
            data = self.prepare_segment(file_dict, start_time)
            yield data
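The weighted draw above deserves a note: seeding with current_epoch (via create_unique_rng) presumably yields the same reproducible stream for a given epoch, and weighting by duration makes sampling uniform over seconds of audio rather than over files. A self-contained illustration with made-up entries:

import random

rng = random.Random(0)  # stand-in for create_unique_rng(epoch)
train_data = [
    {"clean": "a_clean.wav", "noisy": "a_noisy.wav", "duration": 30.0},
    {"clean": "b_clean.wav", "noisy": "b_noisy.wav", "duration": 3.0},
]
file_dict, *_ = rng.choices(
    train_data, k=1, weights=[f["duration"] for f in train_data]
)
# the 30 s file is drawn ten times as often as the 3 s one, so each
# second of training audio is equally likely to be visited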
    def val__getitem__(self, idx):
        return self.prepare_segment(*self._validation[idx])

    def prepare_segment(self, file_dict: dict, start_time: float):

        clean_segment = self.audio(
            file_dict["clean"], offset=start_time, duration=self.duration
        )
        noisy_segment = self.audio(
            file_dict["noisy"], offset=start_time, duration=self.duration
        )
        # zero-pad on the right so every segment is exactly
        # duration * sampling_rate samples long
        clean_segment = F.pad(
            clean_segment,
            (0, int(self.duration * self.sampling_rate - clean_segment.shape[-1])),
        )
        noisy_segment = F.pad(
            noisy_segment,
            (0, int(self.duration * self.sampling_rate - noisy_segment.shape[-1])),
        )
        return {"clean": clean_segment, "noisy": noisy_segment}

    def train__len__(self):
        return math.ceil(
            sum([file["duration"] for file in self.train_data]) / self.duration
        )

    def val__len__(self):
        return len(self._validation)
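Putting the module together, a hedged usage sketch (the folder names and root path are placeholders, and Files is assumed to take the four folder fields named in the docstring). Note that iterating the train loader additionally requires the owning model to be attached, since train__iter__ reads self.model.current_epoch:

from enhancer.data import EnhancerDataset
from enhancer.utils.config import Files

files = Files(
    train_clean="clean_trainset_wav",  # placeholder folder names
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
data = EnhancerDataset(
    name="vctk",
    root_dir="/path/to/VCTK",  # placeholder path
    files=files,
    duration=2.0,
    sampling_rate=16000,
)
data.setup(stage="fit")
val_batch = next(iter(data.val_dataloader()))  # dict of "clean"/"noisy" tensors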


@@ -4,105 +4,115 @@ from re import S
import numpy as np
from scipy.io import wavfile

MATCHING_FNS = ("one_to_one", "one_to_many")
class ProcessorFunctions:
    """
    Preprocessing methods for different types of speech enhancement datasets.
    """

    @staticmethod
    def one_to_one(clean_path, noisy_path):
        """
        Each clean audio file is paired with exactly one noisy audio file
        """
        matching_wavfiles = list()
        clean_filenames = [
            os.path.basename(file)
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        noisy_filenames = [
            os.path.basename(file)
            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
        ]
        # pairs are matched purely by basenames present in both folders
        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
        for file_name in common_filenames:
            sr_clean, clean_file = wavfile.read(
                os.path.join(clean_path, file_name)
            )
            sr_noisy, noisy_file = wavfile.read(
                os.path.join(noisy_path, file_name)
            )
            # keep the pair only if lengths and sample rates agree
            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
                sr_clean == sr_noisy
            ):
                matching_wavfiles.append(
                    {
                        "clean": os.path.join(clean_path, file_name),
                        "noisy": os.path.join(noisy_path, file_name),
                        "duration": clean_file.shape[-1] / sr_clean,
                    }
                )
        return matching_wavfiles
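For orientation, one_to_one expects the two directories to mirror each other by basename, and a pair survives only when lengths and sample rates agree. With a hypothetical layout of clean/p232_001.wav and noisy/p232_001.wav:

pairs = ProcessorFunctions.one_to_one("clean", "noisy")
# -> [{"clean": "clean/p232_001.wav",
#      "noisy": "noisy/p232_001.wav",
#      "duration": 2.35}]  # seconds, i.e. samples / sample rate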
    @staticmethod
    def one_to_many(clean_path, noisy_path):
        """
        One clean audio file can have multiple noisy audio files
        """
        # return a list of pair dicts, mirroring one_to_one, so downstream
        # consumers can iterate entries uniformly
        matching_wavfiles = list()
        clean_filenames = [
            os.path.basename(file)
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
        ]
        for clean_name in clean_filenames:
            # noisy variants end with the clean basename, e.g. *_p232_001.wav
            noisy_filenames = glob.glob(
                os.path.join(noisy_path, f"*_{clean_name}")
            )
            for noisy_file in noisy_filenames:
                sr_clean, clean_data = wavfile.read(
                    os.path.join(clean_path, clean_name)
                )
                sr_noisy, noisy_data = wavfile.read(noisy_file)
                if (clean_data.shape[-1] == noisy_data.shape[-1]) and (
                    sr_clean == sr_noisy
                ):
                    matching_wavfiles.append(
                        {
                            "clean": os.path.join(clean_path, clean_name),
                            "noisy": noisy_file,
                            "duration": clean_data.shape[-1] / sr_clean,
                        }
                    )
        return matching_wavfiles
class Fileprocessor:
    def __init__(self, clean_dir, noisy_dir, matching_function=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.matching_function = matching_function

    @classmethod
    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):

        if matching_function is None:
            # known dataset names imply their pairing scheme
            if name.lower() == "vctk":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
            elif name.lower() == "dns-2020":
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
            else:
                raise ValueError(
                    f"Unknown dataset name {name}; "
                    f"please pass an explicit matching_function"
                )
        else:
            if matching_function not in MATCHING_FNS:
                raise ValueError(
                    f"Invalid matching function! Available options are {MATCHING_FNS}"
                )
            return cls(
                clean_dir,
                noisy_dir,
                getattr(ProcessorFunctions, matching_function),
            )

    def prepare_matching_dict(self):
        if self.matching_function is None:
            raise ValueError("Not a valid matching function")
        return self.matching_function(self.clean_dir, self.noisy_dir)
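Finally, a short sketch of the entry point (paths are placeholders): known dataset names select their pairing scheme automatically, while the string form of matching_function is validated against MATCHING_FNS:

fp = Fileprocessor.from_name("dns-2020", "/data/clean", "/data/noisy")
train_data = fp.prepare_matching_dict()  # list of {"clean", "noisy", "duration"}

fp = Fileprocessor.from_name(
    "my-dataset", "/data/clean", "/data/noisy", matching_function="one_to_one"
)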