diff --git a/enhancer/data/__init__.py b/enhancer/data/__init__.py
index e69de29..3ec018e 100644
--- a/enhancer/data/__init__.py
+++ b/enhancer/data/__init__.py
@@ -0,0 +1 @@
+from enhancer.data.dataset import EnhancerDataset
\ No newline at end of file
diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index 4c485c8..d194167 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -12,9 +12,9 @@ from enhancer.utils.io import Audio
 from enhancer.utils import check_files
 from enhancer.utils.config import Files
+
 
 class TrainDataset(IterableDataset):
-
-    def __init__(self,dataset):
+    def __init__(self, dataset):
         self.dataset = dataset
 
     def __iter__(self):
@@ -23,88 +23,102 @@ class TrainDataset(IterableDataset):
     def __len__(self):
         return self.dataset.train__len__()
 
+
 class ValidDataset(Dataset):
-
-    def __init__(self,dataset):
+    def __init__(self, dataset):
         self.dataset = dataset
 
-    def __getitem__(self,idx):
+    def __getitem__(self, idx):
         return self.dataset.val__getitem__(idx)
 
     def __len__(self):
         return self.dataset.val__len__()
 
-class TaskDataset(pl.LightningDataModule):
 
+class TaskDataset(pl.LightningDataModule):
     def __init__(
         self,
-        name:str,
-        root_dir:str,
-        files:Files,
-        duration:float=1.0,
-        sampling_rate:int=48000,
-        matching_function = None,
+        name: str,
+        root_dir: str,
+        files: Files,
+        duration: float = 1.0,
+        sampling_rate: int = 48000,
+        matching_function=None,
         batch_size=32,
-        num_workers:Optional[int]=None):
+        num_workers: Optional[int] = None,
+    ):
         super().__init__()
         self.name = name
-        self.files,self.root_dir = check_files(root_dir,files)
+        self.files, self.root_dir = check_files(root_dir, files)
         self.duration = duration
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
         self._validation = []
         if num_workers is None:
-            num_workers = multiprocessing.cpu_count()//2
+            num_workers = multiprocessing.cpu_count() // 2
         self.num_workers = num_workers
 
     def setup(self, stage: Optional[str] = None):
-        if stage in ("fit",None):
+        if stage in ("fit", None):
 
-            train_clean = os.path.join(self.root_dir,self.files.train_clean)
-            train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
-            fp = Fileprocessor.from_name(self.name,train_clean,
-                                        train_noisy, self.matching_function)
+            train_clean = os.path.join(self.root_dir, self.files.train_clean)
+            train_noisy = os.path.join(self.root_dir, self.files.train_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, train_clean, train_noisy, self.matching_function
+            )
             self.train_data = fp.prepare_matching_dict()
-
-            val_clean = os.path.join(self.root_dir,self.files.test_clean)
-            val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
-            fp = Fileprocessor.from_name(self.name,val_clean,
-                                        val_noisy, self.matching_function)
+
+            val_clean = os.path.join(self.root_dir, self.files.test_clean)
+            val_noisy = os.path.join(self.root_dir, self.files.test_noisy)
+            fp = Fileprocessor.from_name(
+                self.name, val_clean, val_noisy, self.matching_function
+            )
             val_data = fp.prepare_matching_dict()
 
             for item in val_data:
-                clean,noisy,total_dur = item.values()
+                clean, noisy, total_dur = item.values()
                 if total_dur < self.duration:
                     continue
-                num_segments = round(total_dur/self.duration)
+                num_segments = round(total_dur / self.duration)
                 for index in range(num_segments):
                     start_time = index * self.duration
-                    self._validation.append(({"clean":clean,"noisy":noisy},
-                                            start_time))
+                    self._validation.append(
+                        ({"clean": clean, "noisy": noisy}, start_time)
+                    )
+
     def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            TrainDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
 
     def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
+        return DataLoader(
+            ValidDataset(self),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+        )
+
 
 class EnhancerDataset(TaskDataset):
     """
     Dataset object for creating clean-noisy speech enhancement datasets
 
     paramters:
     name : str
-        name of the dataset 
+        name of the dataset
    root_dir : str
         root directory of the dataset containing clean/noisy folders
     files : Files
-        dataclass containing train_clean, train_noisy, test_clean, test_noisy 
-        folder names (refer cli/train_config/dataset)
+        dataclass containing train_clean, train_noisy, test_clean, test_noisy
+        folder names (refer enhancer.utils.Files dataclass)
     duration : float
         expected audio duration of single audio sample for training
     sampling_rate : int
-        desired sampling rate 
+        desired sampling rate
     batch_size : int
         batch size of each batch
     num_workers : int
@@ -114,71 +128,92 @@ class EnhancerDataset(TaskDataset):
         use one_to_one mapping for datasets with one noisy file for each clean file
         use one_to_many mapping for multiple noisy files for each clean file
-
+
     """
 
     def __init__(
         self,
-        name:str,
-        root_dir:str,
-        files:Files,
+        name: str,
+        root_dir: str,
+        files: Files,
         duration=1.0,
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
-        num_workers:Optional[int]=None):
-
+        num_workers: Optional[int] = None,
+    ):
+
         super().__init__(
             name=name,
             root_dir=root_dir,
             files=files,
             sampling_rate=sampling_rate,
             duration=duration,
-            matching_function = matching_function,
+            matching_function=matching_function,
             batch_size=batch_size,
-            num_workers = num_workers,
-
+            num_workers=num_workers,
         )
         self.sampling_rate = sampling_rate
         self.files = files
-        self.duration = max(1.0,duration)
-        self.audio = Audio(self.sampling_rate,mono=True,return_tensor=True)
+        self.duration = max(1.0, duration)
+        self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
+
+    def setup(self, stage: Optional[str] = None):
 
-    def setup(self, stage:Optional[str]=None):
-
         super().setup(stage=stage)
 
     def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch) 
-
+        rng = create_unique_rng(self.model.current_epoch)
+
         while True:
-            file_dict,*_ = rng.choices(self.train_data,k=1,
-                weights=[file["duration"] for file in self.train_data])
-            file_duration = file_dict['duration']
-            start_time = round(rng.uniform(0,file_duration- self.duration),2)
-            data = self.prepare_segment(file_dict,start_time)
+            file_dict, *_ = rng.choices(
+                self.train_data,
+                k=1,
+                weights=[file["duration"] for file in self.train_data],
+            )
+            file_duration = file_dict["duration"]
+            start_time = round(rng.uniform(0, file_duration - self.duration), 2)
+            data = self.prepare_segment(file_dict, start_time)
             yield data
 
-    def val__getitem__(self,idx):
+    def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
 
-
-    def prepare_segment(self,file_dict:dict, start_time:float):
-        clean_segment = self.audio(file_dict["clean"],
-                        offset=start_time,duration=self.duration)
-        noisy_segment = self.audio(file_dict["noisy"],
-                        offset=start_time,duration=self.duration)
-        clean_segment = F.pad(clean_segment,(0,int(self.duration*self.sampling_rate-clean_segment.shape[-1])))
-        noisy_segment = F.pad(noisy_segment,(0,int(self.duration*self.sampling_rate-noisy_segment.shape[-1])))
-        return {"clean": clean_segment,"noisy":noisy_segment}
-
+    def prepare_segment(self, file_dict: dict, start_time: float):
+
+        clean_segment = self.audio(
+            file_dict["clean"], offset=start_time, duration=self.duration
+        )
+        noisy_segment = self.audio(
+            file_dict["noisy"], offset=start_time, duration=self.duration
+        )
+        clean_segment = F.pad(
+            clean_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - clean_segment.shape[-1]
+                ),
+            ),
+        )
+        noisy_segment = F.pad(
+            noisy_segment,
+            (
+                0,
+                int(
+                    self.duration * self.sampling_rate - noisy_segment.shape[-1]
+                ),
+            ),
+        )
+        return {"clean": clean_segment, "noisy": noisy_segment}
+
     def train__len__(self):
-        return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
+        return math.ceil(
+            sum([file["duration"] for file in self.train_data]) / self.duration
+        )
 
     def val__len__(self):
         return len(self._validation)
-
-
diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py
index eab41a0..106f649 100644
--- a/enhancer/data/fileprocessor.py
+++ b/enhancer/data/fileprocessor.py
@@ -4,105 +4,115 @@ from re import S
 import numpy as np
 from scipy.io import wavfile
 
-MATCHING_FNS = ("one_to_one","one_to_many")
+MATCHING_FNS = ("one_to_one", "one_to_many")
+
 
 class ProcessorFunctions:
+    """
+    Preprocessing methods for different types of speech enhacement datasets.
+    """
 
     @staticmethod
-    def one_to_one(clean_path,noisy_path):
+    def one_to_one(clean_path, noisy_path):
         """
         One clean audio can have only one noisy audio file
         """
 
         matching_wavfiles = list()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
-        noisy_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(noisy_path,"*.wav"))]
-        common_filenames = np.intersect1d(noisy_filenames,clean_filenames)
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]
+        noisy_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
+        ]
+        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
         for file_name in common_filenames:
-            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
-            sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
-            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr_noisy)):
+            sr_clean, clean_file = wavfile.read(
+                os.path.join(clean_path, file_name)
+            )
+            sr_noisy, noisy_file = wavfile.read(
+                os.path.join(noisy_path, file_name)
+            )
+            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                sr_clean == sr_noisy
+            ):
                 matching_wavfiles.append(
-                    {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                    "duration":clean_file.shape[-1]/sr_clean}
-                    )
+                    {
+                        "clean": os.path.join(clean_path, file_name),
+                        "noisy": os.path.join(noisy_path, file_name),
+                        "duration": clean_file.shape[-1] / sr_clean,
+                    }
+                )
 
         return matching_wavfiles
 
     @staticmethod
-    def one_to_many(clean_path,noisy_path):
+    def one_to_many(clean_path, noisy_path):
         """
         One clean audio have multiple noisy audio files
         """
-
+
         matching_wavfiles = dict()
-        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
+        clean_filenames = [
+            file.split("/")[-1]
+            for file in glob.glob(os.path.join(clean_path, "*.wav"))
+        ]
         for clean_file in clean_filenames:
-            noisy_filenames = glob.glob(os.path.join(noisy_path,f"*_{clean_file}.wav"))
+            noisy_filenames = glob.glob(
+                os.path.join(noisy_path, f"*_{clean_file}.wav")
+            )
             for noisy_file in noisy_filenames:
-                sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
+                sr_clean, clean_file = wavfile.read(
+                    os.path.join(clean_path, clean_file)
+                )
                 sr_noisy, noisy_file = wavfile.read(noisy_file)
-                if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                    (sr_clean==sr_noisy)):
+                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
+                    sr_clean == sr_noisy
+                ):
                     matching_wavfiles.update(
-                        {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                        "duration":clean_file.shape[-1]/sr_clean}
-                        )
+                        {
+                            "clean": os.path.join(clean_path, clean_file),
+                            "noisy": noisy_file,
+                            "duration": clean_file.shape[-1] / sr_clean,
+                        }
+                    )
 
         return matching_wavfiles
 
 
 class Fileprocessor:
-
-    def __init__(
-        self,
-        clean_dir,
-        noisy_dir,
-        matching_function = None
-    ):
+    def __init__(self, clean_dir, noisy_dir, matching_function=None):
         self.clean_dir = clean_dir
         self.noisy_dir = noisy_dir
         self.matching_function = matching_function
 
     @classmethod
-    def from_name(cls,
-        name:str,
-        clean_dir,
-        noisy_dir,
-        matching_function=None
-    ):
+    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
 
         if matching_function is None:
             if name.lower() == "vctk":
-                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_one)
+                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
             elif name.lower() == "dns-2020":
-                return cls(clean_dir,noisy_dir, ProcessorFunctions.one_to_many)
+                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
         else:
             if matching_function not in MATCHING_FNS:
-                raise ValueError(F"Invalid matching function! Avaialble options are {MATCHING_FNS}")
+                raise ValueError(
+                    f"Invalid matching function! Avaialble options are {MATCHING_FNS}"
+                )
             else:
-                return cls(clean_dir,noisy_dir, getattr(ProcessorFunctions,matching_function))
-
-
+                return cls(
+                    clean_dir,
+                    noisy_dir,
+                    getattr(ProcessorFunctions, matching_function),
+                )
 
     def prepare_matching_dict(self):
 
         if self.matching_function is None:
             raise ValueError("Not a valid matching function")
 
-        return self.matching_function(self.clean_dir,self.noisy_dir)
-
-
-
-
-
-
-
-
-
-
-
-
+        return self.matching_function(self.clean_dir, self.noisy_dir)
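Usage sketch for the refactored classes (illustrative only; the dataset root and folder names below are placeholders, and Files is assumed to accept these fields as keyword arguments):

    from enhancer.data import EnhancerDataset
    from enhancer.data.fileprocessor import Fileprocessor
    from enhancer.utils.config import Files

    # Placeholder folder names; they must match the layout under root_dir.
    files = Files(
        train_clean="clean_trainset_wav",
        train_noisy="noisy_trainset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )

    dataset = EnhancerDataset(
        name="vctk",  # "vctk" selects ProcessorFunctions.one_to_one matching
        root_dir="/path/to/dataset",  # placeholder path
        files=files,
        duration=1.0,
        sampling_rate=48000,
        batch_size=32,
    )
    dataset.setup(stage="fit")
    val_loader = dataset.val_dataloader()  # batches of {"clean": ..., "noisy": ...}
    # train_dataloader() streams duration-weighted random segments; note that
    # train__iter__ reads self.model.current_epoch, so a model reference must be
    # attached to the datamodule before the train loader is iterated.

    # The clean/noisy pairing can also be inspected directly:
    fp = Fileprocessor.from_name("vctk", "/path/to/clean", "/path/to/noisy")
    pairs = fp.prepare_matching_dict()  # one_to_one returns a list of
    # {"clean": ..., "noisy": ..., "duration": ...} entries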