From 6c4bced3607f23db76d911f9a3471de22db53fa1 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 5 Oct 2022 15:20:43 +0530 Subject: [PATCH] reformat utils --- enhancer/utils/__init__.py | 2 +- enhancer/utils/config.py | 11 +++++------ enhancer/utils/random.py | 29 ++++++++++++----------------- enhancer/utils/utils.py | 21 +++++++++++++-------- 4 files changed, 31 insertions(+), 32 deletions(-) diff --git a/enhancer/utils/__init__.py b/enhancer/utils/__init__.py index 3da7ede..c9f5438 100644 --- a/enhancer/utils/__init__.py +++ b/enhancer/utils/__init__.py @@ -1,3 +1,3 @@ from enhancer.utils.utils import check_files from enhancer.utils.io import Audio -from enhancer.utils.config import Files \ No newline at end of file +from enhancer.utils.config import Files diff --git a/enhancer/utils/config.py b/enhancer/utils/config.py index 1bbc51d..252e6c9 100644 --- a/enhancer/utils/config.py +++ b/enhancer/utils/config.py @@ -1,10 +1,9 @@ from dataclasses import dataclass + @dataclass class Files: - train_clean : str - train_noisy : str - test_clean : str - test_noisy : str - - + train_clean: str + train_noisy: str + test_clean: str + test_noisy: str diff --git a/enhancer/utils/random.py b/enhancer/utils/random.py index 3b1acac..51e09c0 100644 --- a/enhancer/utils/random.py +++ b/enhancer/utils/random.py @@ -3,17 +3,16 @@ import random import torch - -def create_unique_rng(epoch:int): +def create_unique_rng(epoch: int): """create unique random number generator for each (worker_id,epoch) combination""" rng = random.Random() - global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) - global_rank = int(os.environ.get('GLOBAL_RANK',"0")) - local_rank = int(os.environ.get('LOCAL_RANK',"0")) - node_rank = int(os.environ.get('NODE_RANK',"0")) - world_size = int(os.environ.get('WORLD_SIZE',"0")) + global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0")) + global_rank = int(os.environ.get("GLOBAL_RANK", "0")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + node_rank = int(os.environ.get("NODE_RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "0")) worker_info = torch.utils.data.get_worker_info() if worker_info is not None: @@ -24,17 +23,13 @@ def create_unique_rng(epoch:int): worker_id = 0 seed = ( - global_seed - + worker_id - + local_rank * num_workers - + node_rank * num_workers * global_rank - + epoch * num_workers * world_size - ) + global_seed + + worker_id + + local_rank * num_workers + + node_rank * num_workers * global_rank + + epoch * num_workers * world_size + ) rng.seed(seed) return rng - - - - diff --git a/enhancer/utils/utils.py b/enhancer/utils/utils.py index be74dc2..73673ed 100644 --- a/enhancer/utils/utils.py +++ b/enhancer/utils/utils.py @@ -1,19 +1,24 @@ - import os from typing import Optional from enhancer.utils.config import Files -def check_files(root_dir:str, files:Files): - path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')] +def check_files(root_dir: str, files: Files): + + path_variables = [ + member_var + for member_var in dir(files) + if not member_var.startswith("__") + ] for variable in path_variables: - path = getattr(files,variable) - if not os.path.isdir(os.path.join(root_dir,path)): + path = getattr(files, variable) + if not os.path.isdir(os.path.join(root_dir, path)): raise ValueError(f"Invalid {path}, is not a directory") - - return files,root_dir -def merge_dict(default_dict:dict, custom:Optional[dict]=None): + return files, root_dir + + +def merge_dict(default_dict: dict, custom: Optional[dict] = None): params = dict(default_dict) if custom: params.update(custom)