refactor data modules

shahules786 2022-10-05 15:05:40 +05:30
parent 451058c29d
commit b92310c93d
3 changed files with 169 additions and 123 deletions

View File

@@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset
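This new one-line module re-exports EnhancerDataset at the package level, so callers can import it from enhancer.data directly. A minimal sketch of what the re-export enables (the file path is not shown above, but the content reads as a package __init__.py -- an assumption):

    # Both imports now resolve to the same class: the first via the
    # package-level re-export added here, the second via the full module path.
    from enhancer.data import EnhancerDataset
    from enhancer.data.dataset import EnhancerDataset  # noqa: F811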

View File

@@ -12,8 +12,8 @@ from enhancer.utils.io import Audio
 from enhancer.utils import check_files
 from enhancer.utils.config import Files


 class TrainDataset(IterableDataset):
     def __init__(self, dataset):
         self.dataset = dataset
@@ -23,8 +23,8 @@ class TrainDataset(IterableDataset):
     def __len__(self):
         return self.dataset.train__len__()


 class ValidDataset(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset
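Both wrapper classes are thin delegates: each holds a reference to the parent datamodule and forwards its dunder methods back to it, so all file bookkeeping lives in one object. A sketch of the pattern, where the __iter__ forwarding is an assumption consistent with the train__len__ delegation visible above (the generator body itself appears in a later hunk):

    from torch.utils.data import IterableDataset

    class TrainDataset(IterableDataset):
        def __init__(self, dataset):
            self.dataset = dataset  # the parent TaskDataset datamodule

        def __iter__(self):
            # Assumed delegation to the datamodule's segment generator.
            return self.dataset.train__iter__()

        def __len__(self):
            return self.dataset.train__len__()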
@@ -34,8 +34,8 @@ class ValidDataset(Dataset):
     def __len__(self):
         return self.dataset.val__len__()


 class TaskDataset(pl.LightningDataModule):
     def __init__(
         self,
         name: str,
@@ -45,7 +45,8 @@ class TaskDataset(pl.LightningDataModule):
         sampling_rate: int = 48000,
         matching_function=None,
         batch_size=32,
-        num_workers:Optional[int]=None):
+        num_workers: Optional[int] = None,
+    ):

         super().__init__()
         self.name = name
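Since num_workers defaults to None, the actual worker count must be resolved elsewhere; a common fallback (an assumption here -- the resolution logic is outside the visible hunks) is to derive it from the CPU count:

    import multiprocessing

    # Hypothetical fallback for num_workers=None; not shown in this diff.
    if num_workers is None:
        num_workers = max(1, multiprocessing.cpu_count() // 2)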
@@ -65,14 +66,16 @@ class TaskDataset(pl.LightningDataModule):
         train_clean = os.path.join(self.root_dir, self.files.train_clean)
         train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
-        fp = Fileprocessor.from_name(self.name,train_clean,
-                                     train_noisy, self.matching_function)
+        fp = Fileprocessor.from_name(
+            self.name, train_clean, train_noisy, self.matching_function
+        )
         self.train_data = fp.prepare_matching_dict()

         val_clean = os.path.join(self.root_dir, self.files.test_clean)
         val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
-        fp = Fileprocessor.from_name(self.name,val_clean,
-                                     val_noisy, self.matching_function)
+        fp = Fileprocessor.from_name(
+            self.name, val_clean, val_noisy, self.matching_function
+        )
         val_data = fp.prepare_matching_dict()

         for item in val_data:
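Each Fileprocessor call above yields one matching dictionary per split. Going by the ProcessorFunctions file later in this commit, every record holds a clean path, a noisy path, and the clip duration in seconds; a sketch of the expected shape with illustrative file names:

    # Illustrative only -- real paths come from root_dir and the Files dataclass.
    train_data = [
        {
            "clean": "/data/clean_trainset/p226_001.wav",
            "noisy": "/data/noisy_trainset/p226_001.wav",
            "duration": 3.47,
        },
    ]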
@@ -82,13 +85,24 @@ class TaskDataset(pl.LightningDataModule):
             num_segments = round(total_dur / self.duration)
             for index in range(num_segments):
                 start_time = index * self.duration
-                self._validation.append(({"clean":clean,"noisy":noisy},
-                                         start_time))
+                self._validation.append(
+                    ({"clean": clean, "noisy": noisy}, start_time)
+                )

     def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            TrainDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )

     def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            ValidDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )


 class EnhancerDataset(TaskDataset):
     """
@@ -100,7 +114,7 @@ class EnhancerDataset(TaskDataset):
         root directory of the dataset containing clean/noisy folders
     files : Files
         dataclass containing train_clean, train_noisy, test_clean, test_noisy
-        folder names (refer cli/train_config/dataset)
+        folder names (refer enhancer.utils.Files dataclass)
     duration : float
         expected audio duration of single audio sample for training
     sampling_rate : int
@@ -126,7 +140,8 @@ class EnhancerDataset(TaskDataset):
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
-        num_workers:Optional[int]=None):
+        num_workers: Optional[int] = None,
+    ):

         super().__init__(
             name=name,
@@ -137,7 +152,6 @@ class EnhancerDataset(TaskDataset):
             matching_function=matching_function,
             batch_size=batch_size,
             num_workers=num_workers,
-
         )

         self.sampling_rate = sampling_rate
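With the constructor complete, the datamodule can be handed straight to Lightning. A minimal usage sketch: the folder names, parameter values, and the model object are hypothetical, and the Files fields follow the docstring above:

    import pytorch_lightning as pl
    from enhancer.data import EnhancerDataset
    from enhancer.utils.config import Files

    files = Files(
        train_clean="clean_trainset_wav",
        train_noisy="noisy_trainset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(
        name="vctk", root_dir="/data/VCTK", files=files,
        duration=2.0, sampling_rate=48000, batch_size=32,
    )
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=dataset)  # model: any LightningModule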
@@ -155,9 +169,12 @@ class EnhancerDataset(TaskDataset):
         while True:
-            file_dict,*_ = rng.choices(self.train_data,k=1,
-                weights=[file["duration"] for file in self.train_data])
-            file_duration = file_dict['duration']
+            file_dict, *_ = rng.choices(
+                self.train_data,
+                k=1,
+                weights=[file["duration"] for file in self.train_data],
+            )
+            file_duration = file_dict["duration"]
             start_time = round(rng.uniform(0, file_duration - self.duration), 2)
             data = self.prepare_segment(file_dict, start_time)
             yield data
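The reflowed rng.choices call implements duration-weighted sampling: a file's chance of donating the next training segment is proportional to its length, so sampling is roughly uniform over audio time rather than over files. A self-contained sketch of the same technique:

    import random

    rng = random.Random(42)
    files = [{"path": "a.wav", "duration": 2.0},
             {"path": "b.wav", "duration": 8.0}]
    # b.wav is 4x longer, so it is picked ~4x as often.
    file_dict, *_ = rng.choices(
        files, k=1, weights=[f["duration"] for f in files]
    )
    start_time = round(rng.uniform(0, file_dict["duration"] - 1.0), 2)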
@@ -167,18 +184,36 @@ class EnhancerDataset(TaskDataset):
     def prepare_segment(self, file_dict: dict, start_time: float):
-        clean_segment = self.audio(file_dict["clean"],
-            offset=start_time,duration=self.duration)
-        noisy_segment = self.audio(file_dict["noisy"],
-            offset=start_time,duration=self.duration)
-        clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
-        noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
+        clean_segment = self.audio(
+            file_dict["clean"], offset=start_time, duration=self.duration
+        )
+        noisy_segment = self.audio(
+            file_dict["noisy"], offset=start_time, duration=self.duration
+        )
+        clean_segment = F.pad(
+            clean_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - clean_segment.shape[-1]
+                ),
+            ),
+        )
+        noisy_segment = F.pad(
+            noisy_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
+                ),
+            ),
+        )
         return {"clean": clean_segment, "noisy": noisy_segment}

     def train__len__(self):
-        return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
+        return math.ceil(
+            sum([file["duration"] for file in self.train_data]) / self.duration
+        )

     def val__len__(self):
         return len(self._validation)
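The F.pad calls guarantee a fixed segment length: a chunk read near the end of a file may come back short, and right-padding the time axis with zeros brings it to exactly duration * sampling_rate samples so batches stack cleanly. A self-contained sketch of the same arithmetic:

    import torch
    import torch.nn.functional as F

    sampling_rate, duration = 48000, 2.0
    segment = torch.randn(1, 90000)  # short final chunk (< 96000 samples)
    target = int(duration * sampling_rate)
    # The pad tuple is (left, right) for the last dimension.
    segment = F.pad(segment, (0, target - segment.shape[-1]))
    assert segment.shape[-1] == target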

View File

@@ -6,7 +6,11 @@ from scipy.io import wavfile
 MATCHING_FNS = ("one_to_one", "one_to_many")


 class ProcessorFunctions:
+    """
+    Preprocessing methods for different types of speech enhancement datasets.
+    """

     @staticmethod
     def one_to_one(clean_path, noisy_path):
@@ -15,19 +19,33 @@ class ProcessorFunctions:
         """
         matching_wavfiles = list()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
-        noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))]
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]
+        noisy_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
+        ]
         common_filenames = np.intersect1d(noisy_filenames, clean_filenames)

         for file_name in common_filenames:
-            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
-            sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
-            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr_noisy)):
+            sr_clean, clean_file = wavfile.read(
+                os.path.join(clean_path, file_name)
+            )
+            sr_noisy, noisy_file = wavfile.read(
+                os.path.join(noisy_path, file_name)
+            )
+            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                sr_clean == sr_noisy
+            ):
                 matching_wavfiles.append(
-                    {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                     "duration":clean_file.shape[-1]/sr_clean}
+                    {
+                        "clean": os.path.join(clean_path, file_name),
+                        "noisy": os.path.join(noisy_path, file_name),
+                        "duration": clean_file.shape[-1] / sr_clean,
+                    }
                 )

         return matching_wavfiles
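one_to_one pairs recordings by identical basenames across the clean and noisy folders, then keeps a pair only when sample counts and sample rates agree, so mismatched files are dropped rather than poisoning training. The name-matching step in isolation:

    import numpy as np

    clean_filenames = ["p226_001.wav", "p226_002.wav"]
    noisy_filenames = ["p226_002.wav", "p226_003.wav"]
    # Only basenames present in both folders survive.
    common = np.intersect1d(noisy_filenames, clean_filenames)
    print(common)  # ['p226_002.wav']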
@@ -38,42 +56,42 @@ class ProcessorFunctions:
         """
         matching_wavfiles = dict()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]

         for clean_file in clean_filenames:
-            noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav"))
+            noisy_filenames = glob.glob(
+                os.path.join(noisy_path, f"*_{clean_file}.wav")
+            )
             for noisy_file in noisy_filenames:
-                sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
+                sr_clean, clean_file = wavfile.read(
+                    os.path.join(clean_path, clean_file)
+                )
                 sr_noisy, noisy_file = wavfile.read(noisy_file)
-                if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                    (sr_clean==sr_noisy)):
+                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                    sr_clean == sr_noisy
+                ):
                     matching_wavfiles.update(
-                        {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                         "duration":clean_file.shape[-1]/sr_clean}
+                        {
+                            "clean": os.path.join(clean_path, clean_file),
+                            "noisy": noisy_file,
+                            "duration": clean_file.shape[-1] / sr_clean,
+                        }
                     )

         return matching_wavfiles


 class Fileprocessor:
-    def __init__(
-        self,
-        clean_dir,
-        noisy_dir,
-        matching_function = None
-    ):
+    def __init__(self, clean_dir, noisy_dir, matching_function=None):
         self.clean_dir = clean_dir
         self.noisy_dir = noisy_dir
         self.matching_function = matching_function

     @classmethod
-    def from_name(cls,
-            name:str,
-            clean_dir,
-            noisy_dir,
-            matching_function=None
-            ):
+    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
         if matching_function is None:
             if name.lower() == "vctk":
@@ -82,11 +100,15 @@ class Fileprocessor:
                 return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
         else:
             if matching_function not in MATCHING_FNS:
-                raise ValueError(F"Invalid matching function! Avaialble options are {MATCHING_FNS}")
+                raise ValueError(
+                    f"Invalid matching function! Available options are {MATCHING_FNS}"
+                )
             else:
-                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
+                return cls(
+                    clean_dir,
+                    noisy_dir,
+                    getattr(ProcessorFunctions, matching_function),
+                )

     def prepare_matching_dict(self):
@@ -94,15 +116,3 @@ class Fileprocessor:
             raise ValueError("Not a valid matching function")

         return self.matching_function(self.clean_dir, self.noisy_dir)
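Taken together, from_name gives the datamodules a single entry point: a recognised dataset name selects a default matching strategy, while an explicit matching_function string is validated against MATCHING_FNS before being looked up on ProcessorFunctions. A usage sketch with hypothetical directories:

    # Name-based dispatch: a known name such as "vctk" picks its default.
    fp = Fileprocessor.from_name(
        "vctk", "/data/clean_trainset_wav", "/data/noisy_trainset_wav"
    )
    pairs = fp.prepare_matching_dict()

    # Explicit dispatch: the string must be in MATCHING_FNS, else ValueError.
    fp = Fileprocessor.from_name(
        "custom", "/data/clean", "/data/noisy", matching_function="one_to_many"
    )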