config_utils.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. Collection of DeepSpeed configuration utilities
  6. """
  7. import json
  8. import collections
  9. import collections.abc
  10. from functools import reduce
  11. from deepspeed.pydantic_v1 import BaseModel
  12. from deepspeed.utils import logger
  13. class DeepSpeedConfigModel(BaseModel):
  14. """
  15. This class should be used as a base for all DeepSpeed configs. It extends
  16. pydantic.BaseModel to allow for deprecated fields. To enable this feature,
  17. add deprecated=True to pydantic.Field:
  18. my_dep_field: int = Field(0, deprecated=True)
  19. Deprecated Field kwargs:
  20. - deprecated: [True|False], default False
  21. Enables / Disables deprecated fields
  22. - deprecated_msg: str, default ""
  23. Message to include with deprecation warning
  24. - new_param: str, default ""
  25. Name of the field replacing the deprecated field
  26. - set_new_param: [True|False], default True
  27. If new_param is provided, enables setting the value of that param with
  28. deprecated field value
  29. - new_param_fn: callable, default (lambda x: x)
  30. If new_param is provided and set_new_param is True, this function will
  31. modify the value of the deprecated field before placing that value in
  32. the new_param field
  33. Example:
  34. my_new_field is replacing a deprecated my_old_field. The expected type
  35. for my_new_field is int while the expected type for my_old_field is
  36. str. We want to maintain backward compatibility with our configs, so we
  37. define the fields with:
  38. class MyExampleConfig(DeepSpeedConfigModel):
  39. my_new_field: int = 0
  40. my_old_field: str = Field('0',
  41. deprecated=True,
  42. new_param='my_new_field',
  43. new_param_fn=(lambda x: int(x)))
  44. """
  45. def __init__(self, strict=False, **data):
  46. if (not strict): # This is temporary until we refactor all DS configs, allows HF to load models
  47. data = {k: v for k, v in data.items() if (v != "auto" or k == "replace_method")}
  48. super().__init__(**data)
  49. self._deprecated_fields_check(self)
  50. def _process_deprecated_field(self, pydantic_config, field):
  51. # Get information about the deprecated field
  52. fields_set = pydantic_config.__fields_set__
  53. dep_param = field.name
  54. kwargs = field.field_info.extra
  55. new_param_fn = kwargs.get("new_param_fn", lambda x: x)
  56. param_value = new_param_fn(getattr(pydantic_config, dep_param))
  57. new_param = kwargs.get("new_param", "")
  58. dep_msg = kwargs.get("deprecated_msg", "")
  59. if dep_param in fields_set:
  60. logger.warning(f"Config parameter {dep_param} is deprecated" +
  61. (f" use {new_param} instead" if new_param else "") + (f". {dep_msg}" if dep_msg else ""))
  62. # Check if there is a new param and if it should be set with a value
  63. if new_param and kwargs.get("set_new_param", True):
  64. # Remove the deprecate field if there is a replacing field
  65. try:
  66. delattr(pydantic_config, dep_param)
  67. except Exception as e:
  68. logger.error(f"Tried removing deprecated '{dep_param}' from config")
  69. raise e
  70. # Set new param value
  71. new_param_nested = new_param.split(".")
  72. if len(new_param_nested) > 1:
  73. # If the new param exists in a subconfig, we need to get
  74. # the fields set for that subconfig
  75. pydantic_config = reduce(getattr, new_param_nested[:-1], pydantic_config)
  76. fields_set = pydantic_config.__fields_set__
  77. new_param_name = new_param_nested[-1]
  78. assert (
  79. new_param_name not in fields_set
  80. ), f"Cannot provide deprecated parameter '{dep_param}' and replacing parameter '{new_param}' together"
  81. # A custom function for converting the old param value to new param value can be provided
  82. try:
  83. setattr(pydantic_config, new_param_name, param_value)
  84. except Exception as e:
  85. logger.error(f"Tried setting value for '{new_param}' with value from deprecated '{dep_param}'")
  86. raise e
  87. def _deprecated_fields_check(self, pydantic_config):
  88. fields = pydantic_config.__fields__
  89. for field in fields.values():
  90. if field.field_info.extra.get("deprecated", False):
  91. self._process_deprecated_field(pydantic_config, field)
  92. class Config:
  93. validate_all = True
  94. validate_assignment = True
  95. use_enum_values = True
  96. allow_population_by_field_name = True
  97. extra = "forbid"
  98. arbitrary_types_allowed = True
  99. def get_config_default(config, field_name):
  100. assert field_name in config.__fields__, f"'{field_name}' is not a field in {config}"
  101. assert not config.__fields__.get(
  102. field_name).required, f"'{field_name}' is a required field and does not have a default value"
  103. return config.__fields__.get(field_name).default
  104. class pp_int(int):
  105. """
  106. A wrapper for integers that will return a custom string or comma-formatted
  107. string of the integer. For example, print(pp_int(1e5)) will return
  108. "10,000". This is useful mainly for auto-generated documentation purposes.
  109. """
  110. def __new__(cls, val, custom_print_str=None):
  111. inst = super().__new__(cls, val)
  112. inst.custom_print_str = custom_print_str
  113. return inst
  114. def __repr__(self):
  115. if self.custom_print_str:
  116. return self.custom_print_str
  117. return f"{self.real:,}"
  118. # adapted from https://stackoverflow.com/a/50701137/9201239
  119. class ScientificNotationEncoder(json.JSONEncoder):
  120. """
  121. This class overrides ``json.dumps`` default formatter.
  122. This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation.
  123. Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it
  124. """
  125. def iterencode(self, o, _one_shot=False, level=0):
  126. indent = self.indent if self.indent is not None else 4
  127. prefix_close = " " * level * indent
  128. level += 1
  129. prefix = " " * level * indent
  130. if isinstance(o, bool):
  131. return "true" if o else "false"
  132. elif isinstance(o, float) or isinstance(o, int):
  133. if o > 1e3:
  134. return f"{o:e}"
  135. else:
  136. return f"{o}"
  137. elif isinstance(o, collections.abc.Mapping):
  138. x = [f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, v in o.items()]
  139. return "{" + ", ".join(x) + f"\n{prefix_close}" + "}"
  140. elif isinstance(o, collections.abc.Sequence) and not isinstance(o, str):
  141. return f"[{ f', '.join(map(self.iterencode, o)) }]"
  142. return "\n, ".join(super().iterencode(o, _one_shot))
  143. class DeepSpeedConfigObject(object):
  144. """
  145. For json serialization
  146. """
  147. def repr(self):
  148. return self.__dict__
  149. def __repr__(self):
  150. return json.dumps(
  151. self.__dict__,
  152. sort_keys=True,
  153. indent=4,
  154. cls=ScientificNotationEncoder,
  155. )
  156. def get_scalar_param(param_dict, param_name, param_default_value):
  157. return param_dict.get(param_name, param_default_value)
  158. def get_list_param(param_dict, param_name, param_default_value):
  159. return param_dict.get(param_name, param_default_value)
  160. def get_dict_param(param_dict, param_name, param_default_value):
  161. return param_dict.get(param_name, param_default_value)
  162. def dict_raise_error_on_duplicate_keys(ordered_pairs):
  163. """Reject duplicate keys."""
  164. d = dict((k, v) for k, v in ordered_pairs)
  165. if len(d) != len(ordered_pairs):
  166. counter = collections.Counter([pair[0] for pair in ordered_pairs])
  167. keys = [key for key, value in counter.items() if value > 1]
  168. raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys))
  169. return d