Merge pull request #9 from shahules786/dev-reformat

refactor data modules
Authored by Shahul ES, 2022-10-05 15:06:57 +05:30; committed by GitHub
commit f40e7b500c
3 changed files with 169 additions and 123 deletions

View File

@ -0,0 +1 @@
from enhancer.data.dataset import EnhancerDataset
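
The single added line re-exports EnhancerDataset at package level. A minimal sketch of what this enables, assuming the file shown is the package's __init__ module (the path itself is not visible in this view):

# Both imports now resolve to the same class; the second, shorter form is
# what the re-export above makes possible.
from enhancer.data.dataset import EnhancerDataset
from enhancer.data import EnhancerDataset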

View File

@ -12,8 +12,8 @@ from enhancer.utils.io import Audio
from enhancer.utils import check_files
from enhancer.utils.config import Files
class TrainDataset(IterableDataset):
def __init__(self, dataset):
self.dataset = dataset
@ -23,8 +23,8 @@ class TrainDataset(IterableDataset):
def __len__(self):
return self.dataset.train__len__()
class ValidDataset(Dataset):
def __init__(self, dataset):
self.dataset = dataset
@ -34,8 +34,8 @@ class ValidDataset(Dataset):
def __len__(self):
return self.dataset.val__len__()
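
TrainDataset and ValidDataset are thin views that delegate everything back to the owning TaskDataset; only __init__ and __len__ survive in these hunks, so the forwarding of __getitem__/__iter__ is inferred. A minimal sketch of the pattern, with illustrative names:

from torch.utils.data import Dataset

class Owner:
    """Stand-in for TaskDataset; holds the actual data logic."""
    def val__getitem__(self, idx):
        return {"clean": idx, "noisy": idx}
    def val__len__(self):
        return 4

class ValidView(Dataset):
    """Mirrors ValidDataset: every call is forwarded to the owner."""
    def __init__(self, dataset):
        self.dataset = dataset
    def __getitem__(self, idx):
        return self.dataset.val__getitem__(idx)  # forwarding assumed, not shown in the hunk
    def __len__(self):
        return self.dataset.val__len__()

view = ValidView(Owner())
assert len(view) == 4 and view[2]["clean"] == 2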
class TaskDataset(pl.LightningDataModule):
def __init__(
self,
name: str,
@ -45,7 +45,8 @@ class TaskDataset(pl.LightningDataModule):
sampling_rate: int = 48000,
matching_function=None,
batch_size=32,
num_workers:Optional[int]=None):
num_workers: Optional[int] = None,
):
super().__init__()
self.name = name
@ -65,14 +66,16 @@ class TaskDataset(pl.LightningDataModule):
train_clean = os.path.join(self.root_dir, self.files.train_clean)
train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
fp = Fileprocessor.from_name(self.name,train_clean,
train_noisy, self.matching_function)
fp = Fileprocessor.from_name(
self.name, train_clean, train_noisy, self.matching_function
)
self.train_data = fp.prepare_matching_dict()
val_clean = os.path.join(self.root_dir, self.files.test_clean)
val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
fp = Fileprocessor.from_name(self.name,val_clean,
val_noisy, self.matching_function)
fp = Fileprocessor.from_name(
self.name, val_clean, val_noisy, self.matching_function
)
val_data = fp.prepare_matching_dict()
for item in val_data:
@ -82,13 +85,24 @@ class TaskDataset(pl.LightningDataModule):
num_segments = round(total_dur / self.duration)
for index in range(num_segments):
start_time = index * self.duration
self._validation.append(({"clean":clean,"noisy":noisy},
start_time))
self._validation.append(
({"clean": clean, "noisy": noisy}, start_time)
)
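
The validation split is deterministic: each file is cut into back-to-back windows of self.duration seconds, and round() decides whether a trailing partial window gets its own segment. A worked example with illustrative numbers:

# A 10.4 s file with duration == 2.0 yields round(5.2) == 5 segments.
duration = 2.0
total_dur = 10.4
num_segments = round(total_dur / duration)
starts = [i * duration for i in range(num_segments)]
print(num_segments, starts)  # 5 [0.0, 2.0, 4.0, 6.0, 8.0]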
def train_dataloader(self):
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
return DataLoader(
TrainDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
def val_dataloader(self):
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
return DataLoader(
ValidDataset(self),
batch_size=self.batch_size,
num_workers=self.num_workers,
)
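
Because TrainDataset is an IterableDataset, the training loader takes no sampler or shuffle flag; it simply batches whatever the stream yields. A self-contained sketch of that behaviour with a toy stream:

import torch
from torch.utils.data import DataLoader, IterableDataset

class ToyStream(IterableDataset):
    """Stands in for TrainDataset; yields 8 one-sample tensors."""
    def __iter__(self):
        for i in range(8):
            yield torch.tensor([float(i)])

loader = DataLoader(ToyStream(), batch_size=4, num_workers=0)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 1]), twice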
class EnhancerDataset(TaskDataset):
"""
@ -100,7 +114,7 @@ class EnhancerDataset(TaskDataset):
root directory of the dataset containing clean/noisy folders
files : Files
dataclass containing train_clean, train_noisy, test_clean, test_noisy
folder names (refer to cli/train_config/dataset)
folder names (refer to the enhancer.utils.Files dataclass)
duration : float
expected audio duration of single audio sample for training
sampling_rate : int
@ -126,7 +140,8 @@ class EnhancerDataset(TaskDataset):
sampling_rate=48000,
matching_function=None,
batch_size=32,
num_workers:Optional[int]=None):
num_workers: Optional[int] = None,
):
super().__init__(
name=name,
@ -137,7 +152,6 @@ class EnhancerDataset(TaskDataset):
matching_function=matching_function,
batch_size=batch_size,
num_workers=num_workers,
)
self.sampling_rate = sampling_rate
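
Putting the pieces together, a hypothetical construction of the datamodule. The Files field names follow the attributes read in prepare_data above; the constructor keywords are inferred from the attributes the class uses, and the folder names and root path are placeholders:

from enhancer.utils.config import Files
from enhancer.data import EnhancerDataset

files = Files(
    train_clean="clean_trainset_wav",   # placeholder folder names
    train_noisy="noisy_trainset_wav",
    test_clean="clean_testset_wav",
    test_noisy="noisy_testset_wav",
)
dm = EnhancerDataset(
    name="vctk",
    root_dir="/data/vctk",              # placeholder path
    files=files,
    duration=1.0,
    sampling_rate=48000,
    batch_size=32,
)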
@ -155,9 +169,12 @@ class EnhancerDataset(TaskDataset):
while True:
file_dict,*_ = rng.choices(self.train_data,k=1,
weights=[file["duration"] for file in self.train_data])
file_duration = file_dict['duration']
file_dict, *_ = rng.choices(
self.train_data,
k=1,
weights=[file["duration"] for file in self.train_data],
)
file_duration = file_dict["duration"]
start_time = round(rng.uniform(0, file_duration - self.duration), 2)
data = self.prepare_segment(file_dict, start_time)
yield data
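
train__iter__ samples files with probability proportional to their duration, so every second of training audio is (approximately) equally likely to be visited. A quick check of that property with rng.choices:

import random

train_data = [  # illustrative metadata, same shape as prepare_matching_dict output
    {"clean": "a_clean.wav", "noisy": "a_noisy.wav", "duration": 9.0},
    {"clean": "b_clean.wav", "noisy": "b_noisy.wav", "duration": 1.0},
]
rng = random.Random(0)
picks = [
    rng.choices(train_data, k=1, weights=[f["duration"] for f in train_data])[0]
    for _ in range(1000)
]
print(sum(p["duration"] == 9.0 for p in picks) / 1000)  # ~0.9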
@ -167,18 +184,36 @@ class EnhancerDataset(TaskDataset):
def prepare_segment(self, file_dict: dict, start_time: float):
clean_segment = self.audio(file_dict["clean"],
offset=start_time,duration=self.duration)
noisy_segment = self.audio(file_dict["noisy"],
offset=start_time,duration=self.duration)
clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
clean_segment = self.audio(
file_dict["clean"], offset=start_time, duration=self.duration
)
noisy_segment = self.audio(
file_dict["noisy"], offset=start_time, duration=self.duration
)
clean_segment = F.pad(
clean_segment,
(
0,
int(
self.duration * self.sampling_rate - clean_segment.shape[-1]
),
),
)
noisy_segment = F.pad(
noisy_segment,
(
0,
int(
self.duration * self.sampling_rate - noisy_segment.shape[-1]
),
),
)
return {"clean": clean_segment, "noisy": noisy_segment}
def train__len__(self):
return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
return math.ceil(
sum([file["duration"] for file in self.train_data]) / self.duration
)
def val__len__(self):
return len(self._validation)
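
train__len__ defines an epoch as enough fixed-length segments to cover the total training duration once, e.g.:

import math
# 25.5 s of audio in segments of 2.0 s -> math.ceil(12.75) == 13 steps per epoch
print(math.ceil(25.5 / 2.0))  # 13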

View File

@ -6,7 +6,11 @@ from scipy.io import wavfile
MATCHING_FNS = ("one_to_one", "one_to_many")
class ProcessorFunctions:
"""
Preprocessing methods for different types of speech enhancement datasets.
"""
@staticmethod
def one_to_one(clean_path, noisy_path):
@ -15,19 +19,33 @@ class ProcessorFunctions:
"""
matching_wavfiles = list()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))]
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
noisy_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(noisy_path, "*.wav"))
]
common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
for file_name in common_filenames:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr_noisy)):
sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, file_name)
)
sr_noisy, noisy_file = wavfile.read(
os.path.join(noisy_path, file_name)
)
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.append(
{"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
"duration":clean_file.shape[-1]/sr_clean}
{
"clean": os.path.join(clean_path, file_name),
"noisy": os.path.join(noisy_path, file_name),
"duration": clean_file.shape[-1] / sr_clean,
}
)
return matching_wavfiles
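
one_to_one pairs files purely by basename: np.intersect1d keeps the names present in both folders, and a pair is only accepted when lengths and sample rates agree. The name matching in isolation, with illustrative filenames:

import numpy as np

clean_names = ["p226_001.wav", "p226_002.wav", "p227_001.wav"]
noisy_names = ["p226_001.wav", "p227_001.wav", "p287_005.wav"]
print(np.intersect1d(noisy_names, clean_names))
# ['p226_001.wav' 'p227_001.wav']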
@ -38,42 +56,42 @@ class ProcessorFunctions:
"""
matching_wavfiles = dict()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
clean_filenames = [
file.split("/")[-1]
for file in glob.glob(os.path.join(clean_path, "*.wav"))
]
for clean_file in clean_filenames:
noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav"))
noisy_filenames = glob.glob(
os.path.join(noisy_path, f"*_{clean_file}.wav")
)
for noisy_file in noisy_filenames:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
sr_clean, clean_file = wavfile.read(
os.path.join(clean_path, clean_file)
)
sr_noisy, noisy_file = wavfile.read(noisy_file)
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr_noisy)):
if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
sr_clean == sr_noisy
):
matching_wavfiles.update(
{"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
"duration":clean_file.shape[-1]/sr_clean}
{
"clean": os.path.join(clean_path, clean_file),
"noisy": noisy_file,
"duration": clean_file.shape[-1] / sr_clean,
}
)
return matching_wavfiles
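
one_to_many looks up every noisy variant of a clean file by glob suffix. Note that the pattern keeps the clean basename's .wav and appends another, so it matches on-disk names of the form <condition>_<clean_name>.wav.wav; fnmatch makes the convention concrete (filenames are illustrative):

import fnmatch

clean_file = "p226_001.wav"
pattern = f"*_{clean_file}.wav"  # -> "*_p226_001.wav.wav"
noisy = ["babble_0db_p226_001.wav.wav", "street_5db_p226_001.wav.wav"]
print(fnmatch.filter(noisy, pattern))  # both names match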
class Fileprocessor:
def __init__(
self,
clean_dir,
noisy_dir,
matching_function = None
):
def __init__(self, clean_dir, noisy_dir, matching_function=None):
self.clean_dir = clean_dir
self.noisy_dir = noisy_dir
self.matching_function = matching_function
@classmethod
def from_name(cls,
name:str,
clean_dir,
noisy_dir,
matching_function=None
):
def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
if matching_function is None:
if name.lower() == "vctk":
@ -82,11 +100,15 @@ class Fileprocessor:
return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
else:
if matching_function not in MATCHING_FNS:
raise ValueError(F"Invalid matching function! Avaialble options are {MATCHING_FNS}")
raise ValueError(
f"Invalid matching function! Avaialble options are {MATCHING_FNS}"
)
else:
return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
return cls(
clean_dir,
noisy_dir,
getattr(ProcessorFunctions, matching_function),
)
def prepare_matching_dict(self):
@ -94,15 +116,3 @@ class Fileprocessor:
raise ValueError("Not a valid matching function")
return self.matching_function(self.clean_dir, self.noisy_dir)
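
End-to-end, Fileprocessor.from_name picks a matching function from the dataset name (or an explicit override), and prepare_matching_dict runs it over the two directories. A hypothetical call; the module path and the vctk -> one_to_one mapping are assumptions, since that branch is elided from the hunk above:

from enhancer.data.fileprocessor import Fileprocessor  # module path assumed

fp = Fileprocessor.from_name("vctk", "/data/vctk/clean", "/data/vctk/noisy")
matches = fp.prepare_matching_dict()
# for one_to_one: a list of {"clean": ..., "noisy": ..., "duration": ...} entries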