123 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			123 lines
		
	
	
		
			4.0 KiB
		
	
	
	
		
			Python
		
	
	
	
import glob
 | 
						|
import os
 | 
						|
 | 
						|
import numpy as np
 | 
						|
from scipy.io import wavfile
 | 
						|
 | 
						|
MATCHING_FNS = ("one_to_one", "one_to_many")
 | 
						|
 | 
						|
 | 
						|
class ProcessorFunctions:
 | 
						|
    """
 | 
						|
    Preprocessing methods for different types of speech enhacement datasets.
 | 
						|
    """
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def one_to_one(clean_path, noisy_path):
 | 
						|
        """
 | 
						|
        One clean audio can have only one noisy audio file
 | 
						|
        """
 | 
						|
 | 
						|
        matching_wavfiles = list()
 | 
						|
        clean_filenames = [
 | 
						|
            file.split("/")[-1]
 | 
						|
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
 | 
						|
        ]
 | 
						|
        noisy_filenames = [
 | 
						|
            file.split("/")[-1]
 | 
						|
            for file in glob.glob(os.path.join(noisy_path, "*.wav"))
 | 
						|
        ]
 | 
						|
        common_filenames = np.intersect1d(noisy_filenames, clean_filenames)
 | 
						|
 | 
						|
        for file_name in common_filenames:
 | 
						|
 | 
						|
            sr_clean, clean_file = wavfile.read(
 | 
						|
                os.path.join(clean_path, file_name)
 | 
						|
            )
 | 
						|
            sr_noisy, noisy_file = wavfile.read(
 | 
						|
                os.path.join(noisy_path, file_name)
 | 
						|
            )
 | 
						|
            if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
 | 
						|
                sr_clean == sr_noisy
 | 
						|
            ):
 | 
						|
                matching_wavfiles.append(
 | 
						|
                    {
 | 
						|
                        "clean": os.path.join(clean_path, file_name),
 | 
						|
                        "noisy": os.path.join(noisy_path, file_name),
 | 
						|
                        "duration": clean_file.shape[-1] / sr_clean,
 | 
						|
                    }
 | 
						|
                )
 | 
						|
        return matching_wavfiles
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def one_to_many(clean_path, noisy_path):
 | 
						|
        """
 | 
						|
        One clean audio have multiple noisy audio files
 | 
						|
        """
 | 
						|
 | 
						|
        matching_wavfiles = list()
 | 
						|
        clean_filenames = [
 | 
						|
            file.split("/")[-1]
 | 
						|
            for file in glob.glob(os.path.join(clean_path, "*.wav"))
 | 
						|
        ]
 | 
						|
        for clean_file in clean_filenames:
 | 
						|
            noisy_filenames = glob.glob(
 | 
						|
                os.path.join(noisy_path, f"*_{clean_file}.wav")
 | 
						|
            )
 | 
						|
            for noisy_file in noisy_filenames:
 | 
						|
 | 
						|
                sr_clean, clean_file = wavfile.read(
 | 
						|
                    os.path.join(clean_path, clean_file)
 | 
						|
                )
 | 
						|
                sr_noisy, noisy_file = wavfile.read(noisy_file)
 | 
						|
                if (clean_file.shape[-1] == noisy_file.shape[-1]) and (
 | 
						|
                    sr_clean == sr_noisy
 | 
						|
                ):
 | 
						|
                    matching_wavfiles.append(
 | 
						|
                        {
 | 
						|
                            "clean": os.path.join(clean_path, clean_file),
 | 
						|
                            "noisy": noisy_file,
 | 
						|
                            "duration": clean_file.shape[-1] / sr_clean,
 | 
						|
                        }
 | 
						|
                    )
 | 
						|
 | 
						|
        return matching_wavfiles
 | 
						|
 | 
						|
 | 
						|
class Fileprocessor:
 | 
						|
    def __init__(self, clean_dir, noisy_dir, matching_function=None):
 | 
						|
        self.clean_dir = clean_dir
 | 
						|
        self.noisy_dir = noisy_dir
 | 
						|
        self.matching_function = matching_function
 | 
						|
 | 
						|
    @classmethod
 | 
						|
    def from_name(cls, name: str, clean_dir, noisy_dir, matching_function=None):
 | 
						|
 | 
						|
        if matching_function is None:
 | 
						|
            if name.lower() == "vctk":
 | 
						|
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_one)
 | 
						|
            elif name.lower() == "dns-2020":
 | 
						|
                return cls(clean_dir, noisy_dir, ProcessorFunctions.one_to_many)
 | 
						|
            else:
 | 
						|
                raise ValueError(
 | 
						|
                    f"Invalid matching function, Please use valid matching function from {MATCHING_FNS}"
 | 
						|
                )
 | 
						|
        else:
 | 
						|
            if matching_function not in MATCHING_FNS:
 | 
						|
                raise ValueError(
 | 
						|
                    f"Invalid matching function! Avaialble options are {MATCHING_FNS}"
 | 
						|
                )
 | 
						|
            else:
 | 
						|
                return cls(
 | 
						|
                    clean_dir,
 | 
						|
                    noisy_dir,
 | 
						|
                    getattr(ProcessorFunctions, matching_function),
 | 
						|
                )
 | 
						|
 | 
						|
    def prepare_matching_dict(self):
 | 
						|
 | 
						|
        if self.matching_function is None:
 | 
						|
            raise ValueError("Not a valid matching function")
 | 
						|
 | 
						|
        return self.matching_function(self.clean_dir, self.noisy_dir)
 |