hparams.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. import argparse
  2. import json
  3. import os
  4. import yaml
  5. from utils.commons.os_utils import remove_file
  6. global_print_hparams = True
  7. hparams = {}
  8. class Args:
  9. def __init__(self, **kwargs):
  10. for k, v in kwargs.items():
  11. self.__setattr__(k, v)
  12. def override_config(old_config: dict, new_config: dict):
  13. if new_config.get('__replace', False):
  14. old_config.clear()
  15. for k, v in new_config.items():
  16. if isinstance(v, dict) and k in old_config:
  17. override_config(old_config[k], new_config[k])
  18. else:
  19. old_config[k] = v
  20. def traverse_dict(d, func):
  21. for k in d.keys():
  22. v = d[k]
  23. if isinstance(v, dict):
  24. traverse_dict(v, func)
  25. else:
  26. d[k] = func(v)
  27. def parse_config_ref(v):
  28. if isinstance(v, str) and v.startswith('^'):
  29. return load_config(v[1:], [], set())
  30. return v
  31. def remove_meta_key(d):
  32. for k in list(d.keys()):
  33. v = d[k]
  34. if isinstance(v, dict):
  35. remove_meta_key(v)
  36. else:
  37. if k[:2] == '__':
  38. del d[k]
  39. def load_config(config_fn, config_chains, loaded_configs):
  40. # deep first inheritance and avoid the second visit of one node
  41. if not os.path.exists(config_fn):
  42. print(f"| WARN: {config_fn} not exist.", )
  43. return {}
  44. with open(config_fn) as f:
  45. hparams_ = yaml.safe_load(f)
  46. loaded_configs.add(config_fn)
  47. traverse_dict(hparams_, parse_config_ref)
  48. if 'base_config' in hparams_:
  49. ret_hparams = {}
  50. if not isinstance(hparams_['base_config'], list):
  51. hparams_['base_config'] = [hparams_['base_config']]
  52. for c in hparams_['base_config']:
  53. if c.startswith('.'):
  54. c = f'{os.path.dirname(config_fn)}/{c}'
  55. c = os.path.normpath(c)
  56. if c not in loaded_configs:
  57. override_config(ret_hparams, load_config(c, config_chains, loaded_configs))
  58. override_config(ret_hparams, hparams_)
  59. else:
  60. ret_hparams = hparams_
  61. config_chains.append(config_fn)
  62. return ret_hparams
  63. def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
  64. if config == '' and exp_name == '':
  65. parser = argparse.ArgumentParser(description='')
  66. parser.add_argument('--config', type=str, default='',
  67. help='location of the data corpus')
  68. parser.add_argument('--exp_name', type=str, default='', help='exp_name')
  69. parser.add_argument('-hp', '--hparams', type=str, default='',
  70. help='location of the data corpus')
  71. parser.add_argument('--infer', action='store_true', help='infer')
  72. parser.add_argument('--validate', action='store_true', help='validate')
  73. parser.add_argument('--reset', action='store_true', help='reset hparams')
  74. parser.add_argument('--remove', action='store_true', help='remove old ckpt')
  75. parser.add_argument('--debug', action='store_true', help='debug')
  76. parser.add_argument('--start_rank', type=int, default=0, help='the start rank id for DDP, keep 0 when single-machine multi-GPU')
  77. parser.add_argument('--world_size', type=int, default=-1, help='the total number of GPU used across all machines, keep -1 for single-machine multi-GPU')
  78. parser.add_argument('--init_method', type=str, default='tcp', help='method to init ddp, use tcp or file')
  79. args, unknown = parser.parse_known_args()
  80. if print_hparams:
  81. print("| set_hparams Unknow hparams: ", unknown)
  82. else:
  83. args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
  84. infer=False, validate=False, reset=False, debug=False, remove=False, start_rank=0, world_size=-1, init_method='tcp')
  85. global hparams
  86. assert args.config != '' or args.exp_name != ''
  87. if args.config != '':
  88. assert os.path.exists(args.config), args.config
  89. saved_hparams = {}
  90. args_work_dir = ''
  91. if args.exp_name != '':
  92. args_work_dir = f'checkpoints/{args.exp_name}'
  93. ckpt_config_path = f'{args_work_dir}/config.yaml'
  94. if os.path.exists(ckpt_config_path):
  95. with open(ckpt_config_path) as f:
  96. saved_hparams_ = yaml.safe_load(f)
  97. if saved_hparams_ is not None:
  98. saved_hparams.update(saved_hparams_)
  99. hparams_ = {}
  100. config_chains = []
  101. if args.config != '':
  102. hparams_.update(load_config(args.config, config_chains, set()))
  103. if len(config_chains) > 1 and print_hparams:
  104. print('| Hparams chains: ', config_chains)
  105. if not args.reset:
  106. hparams_.update(saved_hparams)
  107. hparams_['work_dir'] = args_work_dir
  108. # Support config overriding in command line. Support list type config overriding.
  109. # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
  110. if args.hparams != "":
  111. for new_hparam in args.hparams.split(","):
  112. k, v = new_hparam.split("=")
  113. v = v.strip("\'\" ")
  114. config_node = hparams_
  115. for k_ in k.split(".")[:-1]:
  116. config_node = config_node[k_]
  117. k = k.split(".")[-1]
  118. if k in config_node:
  119. if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
  120. if type(config_node[k]) == list:
  121. v = v.replace(" ", ",").replace('^', "\"")
  122. if '|' in v:
  123. tp = type(config_node[k][0]) if len(config_node[k]) else str
  124. config_node[k] = [tp(x) for x in v.split("|") if x != '']
  125. continue
  126. config_node[k] = eval(v)
  127. else:
  128. config_node[k] = type(config_node[k])(v)
  129. else:
  130. config_node[k] = v
  131. try:
  132. config_node[k] = float(v)
  133. except:
  134. pass
  135. try:
  136. config_node[k] = int(v)
  137. except:
  138. pass
  139. if v.lower() in ['false', 'true']:
  140. config_node[k] = v.lower() == 'true'
  141. if args_work_dir != '' and args.remove:
  142. answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
  143. if answer.lower() == "y":
  144. remove_file(args_work_dir)
  145. if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
  146. os.makedirs(hparams_['work_dir'], exist_ok=True)
  147. with open(ckpt_config_path, 'w') as f:
  148. yaml.safe_dump(hparams_, f)
  149. hparams_['infer'] = args.infer
  150. hparams_['debug'] = args.debug
  151. hparams_['validate'] = args.validate
  152. hparams_['exp_name'] = args.exp_name
  153. hparams_['start_rank'] = args.start_rank # useful for multi-machine training
  154. hparams_['world_size'] = args.world_size
  155. hparams_['init_method'] = args.init_method
  156. remove_meta_key(hparams_)
  157. global global_print_hparams
  158. if global_hparams:
  159. hparams.clear()
  160. hparams.update(hparams_)
  161. if print_hparams and global_print_hparams and global_hparams:
  162. print('| Hparams: ', json.dumps(hparams_, indent=2, sort_keys=True))
  163. # for i, (k, v) in enumerate(sorted(hparams_.items())):
  164. # print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
  165. global_print_hparams = False
  166. return hparams_
  167. if __name__ == '__main__':
  168. set_hparams('checkpoints/1205_os_secc2planes/os_secc2plane_trigridv2/config.yaml')