refactor data modules

parent 451058c29d
commit b92310c93d
@@ -0,0 +1 @@
+from enhancer.data.dataset import EnhancerDataset
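The one-line file added above is presumably the package initializer (enhancer/data/__init__.py; the filename is not preserved in this extract). Re-exporting the class there would allow the shorter import:

    # hypothetical usage, assuming the re-export lives in enhancer/data/__init__.py
    from enhancer.data import EnhancerDataset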
@@ -12,9 +12,9 @@ from enhancer.utils.io import Audio
 from enhancer.utils import check_files
 from enhancer.utils.config import Files
 
 
 class TrainDataset(IterableDataset):
-    def __init__(self,dataset):
+    def __init__(self, dataset):
         self.dataset = dataset
 
     def __iter__(self):
@@ -23,88 +23,102 @@ class TrainDataset(IterableDataset):
     def __len__(self):
         return self.dataset.train__len__()
 
 
 class ValidDataset(Dataset):
-    def __init__(self,dataset):
+    def __init__(self, dataset):
         self.dataset = dataset
 
-    def __getitem__(self,idx):
+    def __getitem__(self, idx):
         return self.dataset.val__getitem__(idx)
 
     def __len__(self):
         return self.dataset.val__len__()
 
+
 class TaskDataset(pl.LightningDataModule):
     def __init__(
         self,
-        name:str,
-        root_dir:str,
-        files:Files,
-        duration:float=1.0,
-        sampling_rate:int=48000,
-        matching_function = None,
+        name: str,
+        root_dir: str,
+        files: Files,
+        duration: float = 1.0,
+        sampling_rate: int = 48000,
+        matching_function=None,
         batch_size=32,
-        num_workers:Optional[int]=None):
+        num_workers: Optional[int] = None,
+    ):
         super().__init__()
 
         self.name = name
-        self.files,self.root_dir = check_files(root_dir,files)
+        self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
         self._validation = []
         if num_workers is None:
-            num_workers = multiprocessing.cpu_count()//2
+            num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
 
     def setup(self, stage: Optional[str] = None):
 
-        if stage in ("fit",None):
+        if stage in ("fit", None):
 
-            train_clean = os.path.join(self.root_dir,self.files.train_clean)
-            train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
-            fp = Fileprocessor.from_name(self.name,train_clean,
-                                        train_noisy, self.matching_function)
+            train_clean = os.path.join(self.root_dir, self.files.train_clean)
+            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, train_clean, train_noisy, self.matching_function
+            )
             self.train_data = fp.prepare_matching_dict()
 
-            val_clean = os.path.join(self.root_dir,self.files.test_clean)
-            val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
-            fp = Fileprocessor.from_name(self.name,val_clean,
-                                        val_noisy, self.matching_function)
+            val_clean = os.path.join(self.root_dir, self.files.test_clean)
+            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, val_clean, val_noisy, self.matching_function
+            )
             val_data = fp.prepare_matching_dict()
 
             for item in val_data:
-                clean,noisy,total_dur = item.values()
+                clean, noisy, total_dur = item.values()
                 if total_dur < self.duration:
                     continue
-                num_segments = round(total_dur/self.duration)
+                num_segments = round(total_dur / self.duration)
                 for index in range(num_segments):
                     start_time = index * self.duration
-                    self._validation.append(({"clean":clean,"noisy":noisy},
-                                            start_time))
+                    self._validation.append(
+                        ({"clean": clean, "noisy": noisy}, start_time)
+                    )
 
     def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            TrainDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
 
     def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            ValidDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
 
 
 class EnhancerDataset(TaskDataset):
     """
     Dataset object for creating clean-noisy speech enhancement datasets
 
     parameters:
     name : str
         name of the dataset
     root_dir : str
         root directory of the dataset containing clean/noisy folders
     files : Files
         dataclass containing train_clean, train_noisy, test_clean, test_noisy
-        folder names (refer cli/train_config/dataset)
+        folder names (refer enhancer.utils.Files dataclass)
     duration : float
         expected audio duration of single audio sample for training
     sampling_rate : int
         desired sampling rate
     batch_size : int
         batch size of each batch
     num_workers : int
@@ -114,71 +128,92 @@ class EnhancerDataset(TaskDataset):
         use one_to_one mapping for datasets with one noisy file for each clean file
         use one_to_many mapping for multiple noisy files for each clean file
 
     """
 
     def __init__(
         self,
-        name:str,
-        root_dir:str,
-        files:Files,
+        name: str,
+        root_dir: str,
+        files: Files,
         duration=1.0,
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
-        num_workers:Optional[int]=None):
-
+        num_workers: Optional[int] = None,
+    ):
 
         super().__init__(
             name=name,
             root_dir=root_dir,
             files=files,
             sampling_rate=sampling_rate,
             duration=duration,
-            matching_function = matching_function,
+            matching_function=matching_function,
             batch_size=batch_size,
-            num_workers = num_workers,
-
+            num_workers=num_workers,
         )
 
         self.sampling_rate = sampling_rate
         self.files = files
-        self.duration = max(1.0,duration)
-        self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True)
+        self.duration = max(1.0, duration)
+        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
 
-    def setup(self, stage:Optional[str]=None):
+    def setup(self, stage: Optional[str] = None):
 
         super().setup(stage=stage)
 
     def train__iter__(self):
 
         rng = create_unique_rng(self.model.current_epoch)
 
         while True:
 
-            file_dict,*_ = rng.choices(self.train_data,k=1,
-                        weights=[file["duration"] for file in self.train_data])
-            file_duration = file_dict['duration']
-            start_time = round(rng.uniform(0,file_duration- self.duration),2)
-            data = self.prepare_segment(file_dict,start_time)
+            file_dict, *_ = rng.choices(
+                self.train_data,
+                k=1,
+                weights=[file["duration"] for file in self.train_data],
+            )
+            file_duration = file_dict["duration"]
+            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
+            data = self.prepare_segment(file_dict, start_time)
             yield data
 
-    def val__getitem__(self,idx):
+    def val__getitem__(self, idx):
        return self.prepare_segment(*self._validation[idx])
 
-    def prepare_segment(self,file_dict:dict, start_time:float):
+    def prepare_segment(self, file_dict: dict, start_time: float):
 
-        clean_segment = self.audio(file_dict["clean"],
-                        offset=start_time,duration=self.duration)
-        noisy_segment = self.audio(file_dict["noisy"],
-                        offset=start_time,duration=self.duration)
-        clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
-        noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
-        return {"clean": clean_segment,"noisy":noisy_segment}
+        clean_segment = self.audio(
+            file_dict["clean"], offset=start_time, duration=self.duration
+        )
+        noisy_segment = self.audio(
+            file_dict["noisy"], offset=start_time, duration=self.duration
+        )
+        clean_segment = F.pad(
+            clean_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - clean_segment.shape[-1]
+                ),
+            ),
+        )
+        noisy_segment = F.pad(
+            noisy_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
+                ),
+            ),
+        )
+        return {"clean": clean_segment, "noisy": noisy_segment}
 
     def train__len__(self):
-        return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
+        return math.ceil(
+            sum([file["duration"] for file in self.train_data]) / self.duration
+        )
 
     def val__len__(self):
         return len(self._validation)
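For orientation, a minimal usage sketch of the refactored data module follows. It is not part of the commit: the root path and folder names are placeholders, and Files is constructed with the field names listed in the docstring above.

    # hypothetical usage sketch; paths and folder names are assumptions
    from enhancer.data.dataset import EnhancerDataset
    from enhancer.utils.config import Files

    files = Files(
        train_clean="clean_trainset_wav",  # placeholder folder names under root_dir
        train_noisy="noisy_trainset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(
        name="vctk",            # resolves to ProcessorFunctions.one_to_one
        root_dir="/data/vctk",  # placeholder path
        files=files,
        duration=1.0,
        sampling_rate=48000,
        batch_size=32,
    )
    dataset.setup(stage="fit")
    val_loader = dataset.val_dataloader()  # fixed (file, start_time) segments

    # Note: train__iter__ seeds its RNG from self.model.current_epoch, so a
    # model must be attached to the dataset before iterating train_dataloader().

Each batch is a dict of tensors, {"clean": ..., "noisy": ...}, zero-padded to duration * sampling_rate samples.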
@@ -4,105 +4,115 @@ from re import S
 import numpy as np
 from scipy.io import wavfile
 
-MATCHING_FNS = ("one_to_one","one_to_many")
+MATCHING_FNS = ("one_to_one", "one_to_many")
 
 
 class ProcessorFunctions:
     """
     Preprocessing methods for different types of speech enhancement datasets.
     """
 
     @staticmethod
-    def one_to_one(clean_path,noisy_path):
+    def one_to_one(clean_path, noisy_path):
         """
         One clean audio can have only one noisy audio file
         """
 
         matching_wavfiles = list()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
-        noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))]
-        common_filenames = np.intersect1d(noisy_filenames,clean_filenames)
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]
+        noisy_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
+        ]
+        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
 
         for file_name in common_filenames:
 
-            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
-            sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
-            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr_noisy)):
+            sr_clean, clean_file = wavfile.read(
+                os.path.join(clean_path, file_name)
+            )
+            sr_noisy, noisy_file = wavfile.read(
+                os.path.join(noisy_path, file_name)
+            )
+            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                sr_clean == sr_noisy
+            ):
                 matching_wavfiles.append(
-                    {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                    "duration":clean_file.shape[-1]/sr_clean}
+                    {
+                        "clean": os.path.join(clean_path, file_name),
+                        "noisy": os.path.join(noisy_path, file_name),
+                        "duration": clean_file.shape[-1] / sr_clean,
+                    }
                 )
         return matching_wavfiles
 
     @staticmethod
-    def one_to_many(clean_path,noisy_path):
+    def one_to_many(clean_path, noisy_path):
         """
         One clean audio can have multiple noisy audio files
         """
 
         matching_wavfiles = dict()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]
         for clean_file in clean_filenames:
-            noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav"))
+            noisy_filenames = glob.glob(
+                os.path.join(noisy_path, f"*_{clean_file}.wav")
+            )
             for noisy_file in noisy_filenames:
 
-                sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
+                sr_clean, clean_file = wavfile.read(
+                    os.path.join(clean_path, clean_file)
+                )
                 sr_noisy, noisy_file = wavfile.read(noisy_file)
-                if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                    (sr_clean==sr_noisy)):
+                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                    sr_clean == sr_noisy
+                ):
                     matching_wavfiles.update(
-                        {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                        "duration":clean_file.shape[-1]/sr_clean}
+                        {
+                            "clean": os.path.join(clean_path, clean_file),
+                            "noisy": noisy_file,
+                            "duration": clean_file.shape[-1] / sr_clean,
+                        }
                     )
 
         return matching_wavfiles
 
 
 class Fileprocessor:
-    def __init__(
-        self,
-        clean_dir,
-        noisy_dir,
-        matching_function = None
-    ):
+    def __init__(self, clean_dir, noisy_dir, matching_function=None):
         self.clean_dir = clean_dir
         self.noisy_dir = noisy_dir
         self.matching_function = matching_function
 
     @classmethod
-    def from_name(cls,
-                name:str,
-                clean_dir,
-                noisy_dir,
-                matching_function=None
-                ):
+    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
 
         if matching_function is None:
             if name.lower() == "vctk":
-                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
             elif name.lower() == "dns-2020":
-                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
+                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
         else:
             if matching_function not in MATCHING_FNS:
-                raise ValueError(F"Invalid matching function! Available options are {MATCHING_FNS}")
+                raise ValueError(
+                    f"Invalid matching function! Available options are {MATCHING_FNS}"
+                )
             else:
-                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
+                return cls(
+                    clean_dir,
+                    noisy_dir,
+                    getattr(ProcessorFunctions, matching_function),
+                )
 
     def prepare_matching_dict(self):
 
         if self.matching_function is None:
             raise ValueError("Not a valid matching function")
 
-        return self.matching_function(self.clean_dir,self.noisy_dir)
+        return self.matching_function(self.clean_dir, self.noisy_dir)
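A short sketch of how the two matching strategies above are selected; the import path and directories are assumptions, not shown in this diff:

    # hypothetical usage; module path and directories are placeholders
    from enhancer.data.fileprocessor import Fileprocessor

    # "vctk" selects one_to_one: clean/x.wav is paired with noisy/x.wav
    fp = Fileprocessor.from_name("vctk", "/data/clean", "/data/noisy")
    pairs = fp.prepare_matching_dict()
    # each entry: {"clean": <path>, "noisy": <path>, "duration": <seconds>}

    # "dns-2020" selects one_to_many: noisy files named *_<clean name>.wav
    # are matched back to a single clean file
    fp = Fileprocessor.from_name("dns-2020", "/data/clean", "/data/noisy")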