config_utils.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. """
  2. Copyright (c) Microsoft Corporation
  3. Licensed under the MIT license.
  4. """
  5. """
  6. Collection of DeepSpeed configuration utilities
  7. """
  8. import json
  9. import collections
  10. # adapted from https://stackoverflow.com/a/50701137/9201239
  11. class ScientificNotationEncoder(json.JSONEncoder):
  12. """
  13. This class overrides ``json.dumps`` default formatter.
  14. This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation.
  15. Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it
  16. """
  17. def iterencode(self, o, _one_shot=False, level=0):
  18. indent = self.indent if self.indent is not None else 4
  19. prefix_close = " " * level * indent
  20. level += 1
  21. prefix = " " * level * indent
  22. if isinstance(o, bool):
  23. return "true" if o else "false"
  24. elif isinstance(o, float) or isinstance(o, int):
  25. if o > 1e3:
  26. return f"{o:e}"
  27. else:
  28. return f"{o}"
  29. elif isinstance(o, collections.Mapping):
  30. x = [
  31. f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k,
  32. v in o.items()
  33. ]
  34. return "{" + ', '.join(x) + f"\n{prefix_close}" + "}"
  35. elif isinstance(o, collections.Sequence) and not isinstance(o, str):
  36. return f"[{ f', '.join(map(self.iterencode, o)) }]"
  37. return "\n, ".join(super().iterencode(o, _one_shot))
  38. class DeepSpeedConfigObject(object):
  39. """
  40. For json serialization
  41. """
  42. def repr(self):
  43. return self.__dict__
  44. def __repr__(self):
  45. return json.dumps(
  46. self.__dict__,
  47. sort_keys=True,
  48. indent=4,
  49. cls=ScientificNotationEncoder,
  50. )
  51. def get_scalar_param(param_dict, param_name, param_default_value):
  52. return param_dict.get(param_name, param_default_value)
  53. def get_list_param(param_dict, param_name, param_default_value):
  54. return param_dict.get(param_name, param_default_value)
  55. def get_dict_param(param_dict, param_name, param_default_value):
  56. return param_dict.get(param_name, param_default_value)
  57. def dict_raise_error_on_duplicate_keys(ordered_pairs):
  58. """Reject duplicate keys."""
  59. d = dict((k, v) for k, v in ordered_pairs)
  60. if len(d) != len(ordered_pairs):
  61. counter = collections.Counter([pair[0] for pair in ordered_pairs])
  62. keys = [key for key, value in counter.items() if value > 1]
  63. raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys))
  64. return d