123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- import argparse
- import json
- import os
- import yaml
- from utils.commons.os_utils import remove_file
# Set to False by set_hparams after it has printed the resolved hparams once,
# so later calls in the same process stay quiet.
global_print_hparams: bool = True
# Process-wide hparams; set_hparams(global_hparams=True) clears and refills
# this same dict object, so modules that imported it see the update.
hparams: dict = {}
class Args:
    """Minimal attribute container used in place of an argparse namespace.

    Each keyword argument given to the constructor is stored as an
    attribute of the same name.
    """

    def __init__(self, **kwargs):
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)
def override_config(old_config: dict, new_config: dict):
    """Recursively merge ``new_config`` into ``old_config`` in place.

    Nested dicts are merged key by key. Any other value from ``new_config``
    replaces the corresponding entry of ``old_config``. A dict from
    ``new_config`` also replaces a base value that is not a dict (scalar,
    list or None).

    If ``new_config`` has a truthy ``'__replace'`` key, ``old_config`` is
    emptied first, so the new mapping replaces it wholesale. The
    ``'__replace'`` marker itself is copied over as an ordinary key.

    Args:
        old_config: mapping to update; mutated in place. Dict values taken
            from ``new_config`` are inserted by reference, not copied.
        new_config: overriding mapping.
    """
    if new_config.get('__replace', False):
        old_config.clear()
    for k, v in new_config.items():
        # Recurse only when both sides are dicts; recursing into a non-dict base
        # value (None, str, list, ...) would fail when indexing it.
        if isinstance(v, dict) and isinstance(old_config.get(k), dict):
            override_config(old_config[k], v)
        else:
            old_config[k] = v
def traverse_dict(d, func):
    """Replace every non-dict value in ``d`` (recursively) with ``func(value)``.

    The mapping is modified in place. Nested dicts are descended into rather
    than passed to ``func``. Lists and other containers count as leaves, so
    their elements are not visited individually.
    """
    for key, value in d.items():
        if not isinstance(value, dict):
            # Re-assigning an existing key does not change the dict's size,
            # so this is safe while iterating.
            d[key] = func(value)
            continue
        traverse_dict(value, func)
def parse_config_ref(v):
    """Resolve a ``'^<path>'`` string into the config loaded from ``<path>``.

    Any value that is not a string starting with ``'^'`` is returned unchanged.
    A reference is loaded with a fresh chain list and a fresh visited set,
    independent of the config that contains it.
    """
    is_reference = isinstance(v, str) and v.startswith('^')
    if not is_reference:
        return v
    referenced_path = v[1:]
    return load_config(referenced_path, [], set())
def remove_meta_key(d):
    """Recursively delete meta keys (string keys starting with ``'__'``) in place.

    Only keys whose value is not a dict are deleted. A dict value is always
    recursed into and kept, even when its own key starts with ``'__'``.
    Non-string keys, which YAML allows (e.g. ``1: foo``), are never treated
    as meta keys.
    """
    # Snapshot the keys because entries are deleted while looping.
    for k in list(d.keys()):
        v = d[k]
        if isinstance(v, dict):
            remove_meta_key(v)
        elif isinstance(k, str) and k.startswith('__'):
            del d[k]
def load_config(config_fn, config_chains, loaded_configs):
    """Load a YAML config file, resolving ``base_config`` inheritance depth-first.

    String leaves starting with ``'^'`` are first replaced by the config they
    reference (via ``parse_config_ref``).

    If the file has a ``base_config`` entry (a path or a list of paths), each
    base is loaded and merged in list order. Base paths starting with ``'.'``
    are resolved relative to the directory of ``config_fn``. The file's own
    values are then merged on top with ``override_config``. A base whose
    normalized path is already in ``loaded_configs`` is not loaded again.

    Args:
        config_fn: path of the YAML file to load.
        config_chains: list that receives the path of every loaded file, bases
            before the files that inherit from them; mutated in place.
        loaded_configs: set of normalized paths already visited; mutated in place.

    Returns:
        The merged config dict. Returns ``{}`` (after printing a warning) when
        ``config_fn`` does not exist, and for an empty YAML file.
    """
    if not os.path.exists(config_fn):
        print(f"| WARN: {config_fn} not exist.", )
        return {}
    with open(config_fn, encoding='utf-8') as f:
        hparams_ = yaml.safe_load(f)
    if hparams_ is None:
        # An empty YAML document loads as None; treat it as an empty config
        # instead of crashing in traverse_dict below.
        hparams_ = {}
    # Store the normalized path, because base-config paths are normalized
    # before the membership check below.
    loaded_configs.add(os.path.normpath(config_fn))
    traverse_dict(hparams_, parse_config_ref)
    if 'base_config' in hparams_:
        ret_hparams = {}
        if not isinstance(hparams_['base_config'], list):
            hparams_['base_config'] = [hparams_['base_config']]
        for c in hparams_['base_config']:
            if c.startswith('.'):
                # os.path.join handles a config_fn without a directory part; an
                # f-string '/'-join would turn '' + './x' into the absolute '/x'.
                c = os.path.join(os.path.dirname(config_fn), c)
            c = os.path.normpath(c)
            if c not in loaded_configs:
                override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
        # The child's own values win over everything inherited.
        override_config(ret_hparams, hparams_)
    else:
        ret_hparams = hparams_
    config_chains.append(config_fn)
    return ret_hparams
def _infer_override_value(v):
    """Convert a command-line override string to float, int or bool when it parses as one.

    Steps run in order, and each successful step replaces the previous result:
    ``float(v)``, then ``int(v)``, then a case-insensitive ``'true'``/``'false'``
    check. A string matching none of them is returned unchanged.
    """
    value = v
    try:
        value = float(v)
    except ValueError:
        pass
    try:
        value = int(v)
    except ValueError:
        pass
    if v.lower() in ['false', 'true']:
        value = v.lower() == 'true'
    return value


def _apply_hparams_overrides(config, hparams_str):
    """Apply ``--hparams``-style overrides to ``config`` in place.

    ``hparams_str`` is a comma-separated list of ``key=value`` items, e.g.
    ``"a=1,b.c=2,d=[1 1 1]"``. Dotted keys address nested dicts; every
    intermediate key must already exist.

    When the target key already holds a non-None value, the new value follows
    that value's type:
      - bools, lists and dicts (and the literals ``True``/``False``) are parsed
        with ``eval``;
      - a list may be given space-separated (``[1 1 1]``, with ``^`` standing in
        for ``"``) or ``|``-separated (``1|2|3``, cast to the existing first
        element's type, or ``str`` for an empty list);
      - any other type is cast with its constructor.

    New keys, and keys currently set to None, get a type inferred by
    ``_infer_override_value``.
    """
    for new_hparam in hparams_str.split(","):
        if not new_hparam.strip():
            continue  # tolerate empty items such as a trailing comma ("a=1,")
        # Split on the first '=' only, so values may themselves contain '='.
        k, v = new_hparam.split("=", 1)
        v = v.strip("\'\" ")
        key_path = k.strip().split(".")
        config_node = config
        for k_ in key_path[:-1]:
            config_node = config_node[k_]
        k = key_path[-1]
        old_value = config_node.get(k) if k in config_node else None
        if old_value is None:
            # Unknown key, or a key whose current value is null (NoneType cannot
            # be used as a constructor): infer the type from the string.
            config_node[k] = _infer_override_value(v)
            continue
        if v in ['True', 'False'] or type(old_value) in [bool, list, dict]:
            if isinstance(old_value, list):
                # Shell-friendly list syntax: spaces separate items, '^' means '"'.
                v = v.replace(" ", ",").replace('^', "\"")
                if '|' in v:
                    tp = type(old_value[0]) if len(old_value) else str
                    config_node[k] = [tp(x) for x in v.split("|") if x != '']
                    continue
            # NOTE(security): eval executes arbitrary Python taken from the command
            # line / hparams_str. Never pass untrusted strings here.
            config_node[k] = eval(v)
        else:
            config_node[k] = type(old_value)(v)


def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
    """Resolve the experiment hyper-parameters and return them as a dict.

    When both ``config`` and ``exp_name`` are empty, settings are parsed from
    the command line (unknown arguments are ignored). Otherwise the function
    arguments are used, and the runtime flags (infer, validate, reset, remove,
    debug) are all False.

    Values are layered in this order, later layers winning:
      1. the ``--config`` YAML file together with its ``base_config`` chain;
      2. ``checkpoints/<exp_name>/config.yaml`` if it exists, unless ``--reset``
         is set (shallow top-level update);
      3. ``work_dir`` set to ``checkpoints/<exp_name>`` (``''`` without an exp_name);
      4. ``--hparams`` / ``hparams_str`` overrides (see ``_apply_hparams_overrides``).

    With an exp_name and ``--remove``, the user is asked interactively whether
    to delete the old checkpoint directory. The resolved hparams are written
    to ``checkpoints/<exp_name>/config.yaml`` when that file is missing or
    ``--reset`` is given, except in ``--infer`` mode. Runtime flags
    (``infer``, ``debug``, ``validate``, ``exp_name``, ``start_rank``,
    ``world_size``, ``init_method``) are then added, and ``'__'``-prefixed
    meta keys are stripped.

    Args:
        config: path of the YAML config; must exist when given.
        exp_name: experiment name, used for ``checkpoints/<exp_name>``.
        hparams_str: comma-separated ``key=value`` overrides.
        print_hparams: print unknown CLI args, the config chain and, on the first
            call only, the resolved hparams.
        global_hparams: also clear and refill the module-level ``hparams`` dict.

    Returns:
        The resolved hparams dict.
    """
    global hparams
    global global_print_hparams

    if config == '' and exp_name == '':
        parser = argparse.ArgumentParser(description='')
        parser.add_argument('--config', type=str, default='',
                            help='path to the yaml config file')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('-hp', '--hparams', type=str, default='',
                            help='comma-separated hparam overrides, e.g. "a=1,b.c=2,d=[1 1 1]"')
        parser.add_argument('--infer', action='store_true', help='infer')
        parser.add_argument('--validate', action='store_true', help='validate')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        parser.add_argument('--remove', action='store_true', help='remove old ckpt')
        parser.add_argument('--debug', action='store_true', help='debug')
        parser.add_argument('--start_rank', type=int, default=0, help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
        parser.add_argument('--world_size', type=int, default=-1, help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
        parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
        args, unknown = parser.parse_known_args()
        if print_hparams:
            print("| set_hparams Unknow hparams: ", unknown)
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=False, validate=False, reset=False, debug=False, remove=False,
                    start_rank=0, world_size=-1, init_method='tcp')

    assert args.config != '' or args.exp_name != ''
    if args.config != '':
        assert os.path.exists(args.config), args.config

    # Hparams saved by a previous run of this experiment, if any.
    saved_hparams = {}
    args_work_dir = ''
    ckpt_config_path = ''
    if args.exp_name != '':
        args_work_dir = f'checkpoints/{args.exp_name}'
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            with open(ckpt_config_path, encoding='utf-8') as f:
                saved_hparams_ = yaml.safe_load(f)
            if saved_hparams_ is not None:
                saved_hparams.update(saved_hparams_)

    hparams_ = {}
    config_chains = []
    if args.config != '':
        hparams_.update(load_config(args.config, config_chains, set()))
        if len(config_chains) > 1 and print_hparams:
            print('| Hparams chains: ', config_chains)
    if not args.reset:
        hparams_.update(saved_hparams)
    hparams_['work_dir'] = args_work_dir

    if args.hparams != "":
        _apply_hparams_overrides(hparams_, args.hparams)

    if args_work_dir != '' and args.remove:
        answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
        if answer.lower() == "y":
            remove_file(args_work_dir)
    if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)
        # config.yaml always goes under args_work_dir, which may differ from
        # work_dir if the latter was overridden via --hparams.
        os.makedirs(args_work_dir, exist_ok=True)
        with open(ckpt_config_path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(hparams_, f)

    # Runtime flags are added after saving, so they never end up in config.yaml.
    hparams_['infer'] = args.infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate
    hparams_['exp_name'] = args.exp_name
    hparams_['start_rank'] = args.start_rank  # useful for multi-machine training
    hparams_['world_size'] = args.world_size
    hparams_['init_method'] = args.init_method
    remove_meta_key(hparams_)

    if global_hparams:
        # Refill in place so modules that imported `hparams` see the new values.
        hparams.clear()
        hparams.update(hparams_)
    if print_hparams and global_print_hparams and global_hparams:
        # default=str is for display only: YAML may produce values json cannot
        # serialize, such as datetime.date.
        print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True, default=str))
        global_print_hparams = False
    return hparams_
if __name__ == '__main__':
    # Manual smoke run: resolve hparams from a checkpoint's saved config.
    # set_hparams asserts that this path exists.
    set_hparams(config='checkpoints/1205_os_secc2planes/os_secc2plane_trigridv2/config.yaml')
|