diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py index 98abe8a..5749c36 100644 --- a/enhancer/data/dataset.py +++ b/enhancer/data/dataset.py @@ -66,15 +66,13 @@ class TaskDataset(pl.LightningDataModule): train_clean = os.path.join(self.root_dir,self.files.train_clean) train_noisy = os.path.join(self.root_dir,self.files.train_noisy) fp = Fileprocessor.from_name(self.name,train_clean, - train_noisy,self.sampling_rate, - self.matching_function) + train_noisy, self.matching_function) self.train_data = fp.prepare_matching_dict() val_clean = os.path.join(self.root_dir,self.files.test_clean) val_noisy = os.path.join(self.root_dir,self.files.test_noisy) fp = Fileprocessor.from_name(self.name,val_clean, - val_noisy,self.sampling_rate, - self.matching_function) + val_noisy, self.matching_function) val_data = fp.prepare_matching_dict() for item in val_data: diff --git a/enhancer/data/fileprocessor.py b/enhancer/data/fileprocessor.py index 4df3e23..f903375 100644 --- a/enhancer/data/fileprocessor.py +++ b/enhancer/data/fileprocessor.py @@ -1,12 +1,13 @@ import glob import os +from re import S import numpy as np from scipy.io import wavfile class ProcessorFunctions: @staticmethod - def match_vtck(clean_path,noisy_path,sr): + def match_vtck(clean_path,noisy_path): matching_wavfiles = list() clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] @@ -18,16 +19,15 @@ class ProcessorFunctions: sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name)) sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name)) if ((clean_file.shape[-1]==noisy_file.shape[-1]) and - (sr_clean==sr) and - (sr_noisy==sr)): + (sr_clean==sr_noisy)): matching_wavfiles.append( {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name), - "duration":clean_file.shape[-1]/sr} + "duration":clean_file.shape[-1]/sr_clean} ) return matching_wavfiles @staticmethod - def match_dns2020(clean_path,noisy_path,sr): + def match_dns2020(clean_path,noisy_path): matching_wavfiles = dict() clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] @@ -38,11 +38,10 @@ class ProcessorFunctions: sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file)) sr_noisy, noisy_file = wavfile.read(noisy_file) if ((clean_file.shape[-1]==noisy_file.shape[-1]) and - (sr_clean==sr) and - (sr_noisy==sr)): + (sr_clean==sr_noisy)): matching_wavfiles.update( {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file, - "duration":clean_file.shape[-1]/sr} + "duration":clean_file.shape[-1]/sr_clean} ) return matching_wavfiles @@ -54,12 +53,10 @@ class Fileprocessor: self, clean_dir, noisy_dir, - sr = 16000, matching_function = None ): self.clean_dir = clean_dir self.noisy_dir = noisy_dir - self.sr = sr self.matching_function = matching_function @classmethod @@ -67,23 +64,22 @@ class Fileprocessor: name:str, clean_dir, noisy_dir, - sr, matching_function=None ): if name.lower() == "vctk": - return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_vtck) + return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck) elif name.lower() == "dns-2020": - return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_dns2020) + return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020) else: - return cls(clean_dir,noisy_dir,sr, matching_function) + return cls(clean_dir,noisy_dir, matching_function) def prepare_matching_dict(self): if self.matching_function is None: raise ValueError("Not a valid matching function") - return self.matching_function(self.clean_dir,self.noisy_dir,self.sr) + return self.matching_function(self.clean_dir,self.noisy_dir)