reformat utils

This commit is contained in:
shahules786 2022-10-05 15:20:43 +05:30
parent aca4521ef2
commit 6c4bced360
4 changed files with 31 additions and 32 deletions

View File

@ -1,3 +1,3 @@
from enhancer.utils.utils import check_files from enhancer.utils.utils import check_files
from enhancer.utils.io import Audio from enhancer.utils.io import Audio
from enhancer.utils.config import Files from enhancer.utils.config import Files

View File

@ -1,10 +1,9 @@
from dataclasses import dataclass from dataclasses import dataclass
@dataclass @dataclass
class Files: class Files:
train_clean : str train_clean: str
train_noisy : str train_noisy: str
test_clean : str test_clean: str
test_noisy : str test_noisy: str

View File

@ -3,17 +3,16 @@ import random
import torch import torch
def create_unique_rng(epoch: int):
def create_unique_rng(epoch:int):
"""create unique random number generator for each (worker_id,epoch) combination""" """create unique random number generator for each (worker_id,epoch) combination"""
rng = random.Random() rng = random.Random()
global_seed = int(os.environ.get("PL_GLOBAL_SEED","0")) global_seed = int(os.environ.get("PL_GLOBAL_SEED", "0"))
global_rank = int(os.environ.get('GLOBAL_RANK',"0")) global_rank = int(os.environ.get("GLOBAL_RANK", "0"))
local_rank = int(os.environ.get('LOCAL_RANK',"0")) local_rank = int(os.environ.get("LOCAL_RANK", "0"))
node_rank = int(os.environ.get('NODE_RANK',"0")) node_rank = int(os.environ.get("NODE_RANK", "0"))
world_size = int(os.environ.get('WORLD_SIZE',"0")) world_size = int(os.environ.get("WORLD_SIZE", "0"))
worker_info = torch.utils.data.get_worker_info() worker_info = torch.utils.data.get_worker_info()
if worker_info is not None: if worker_info is not None:
@ -24,17 +23,13 @@ def create_unique_rng(epoch:int):
worker_id = 0 worker_id = 0
seed = ( seed = (
global_seed global_seed
+ worker_id + worker_id
+ local_rank * num_workers + local_rank * num_workers
+ node_rank * num_workers * global_rank + node_rank * num_workers * global_rank
+ epoch * num_workers * world_size + epoch * num_workers * world_size
) )
rng.seed(seed) rng.seed(seed)
return rng return rng

View File

@ -1,19 +1,24 @@
import os import os
from typing import Optional from typing import Optional
from enhancer.utils.config import Files from enhancer.utils.config import Files
def check_files(root_dir:str, files:Files):
path_variables = [member_var for member_var in dir(files) if not member_var.startswith('__')] def check_files(root_dir: str, files: Files):
path_variables = [
member_var
for member_var in dir(files)
if not member_var.startswith("__")
]
for variable in path_variables: for variable in path_variables:
path = getattr(files,variable) path = getattr(files, variable)
if not os.path.isdir(os.path.join(root_dir,path)): if not os.path.isdir(os.path.join(root_dir, path)):
raise ValueError(f"Invalid {path}, is not a directory") raise ValueError(f"Invalid {path}, is not a directory")
return files,root_dir
def merge_dict(default_dict:dict, custom:Optional[dict]=None): return files, root_dir
def merge_dict(default_dict: dict, custom: Optional[dict] = None):
params = dict(default_dict) params = dict(default_dict)
if custom: if custom:
params.update(custom) params.update(custom)