config.py

  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import json
  5. from .constants import *


class ElasticityError(Exception):
    """
    Base exception for all elasticity related errors
    """


class ElasticityConfigError(ElasticityError):
    """
    Elasticity configuration error
    """


class ElasticityIncompatibleWorldSize(ElasticityError):
    """
    Attempting to run a world size that is incompatible with a given elastic config
    """


class ElasticityConfig:
    """
    Elastic config object, constructed from a param dictionary that contains only
    elastic config parameters; example below.

    If elasticity is enabled, the user must specify (at least) max_train_batch_size
    and micro_batch_sizes.

    {
        "enabled": true,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2, 4, 6],
        "min_gpus": 1,
        "max_gpus": 10000,
        "min_time": 20,
        "ignore_non_elastic_batch_info": false,
        "version": 0.1
    }
    """
    def __init__(self, param_dict):
        self.enabled = param_dict.get(ENABLED, ENABLED_DEFAULT)
        if self.enabled:
            # Both batch-size fields are required when elasticity is enabled;
            # otherwise fall back to their defaults.
            if MAX_ACCEPTABLE_BATCH_SIZE in param_dict:
                self.max_acceptable_batch_size = param_dict[MAX_ACCEPTABLE_BATCH_SIZE]
            else:
                raise ElasticityConfigError(
                    f"Elasticity config missing {MAX_ACCEPTABLE_BATCH_SIZE}")
            if MICRO_BATCHES in param_dict:
                self.micro_batches = param_dict[MICRO_BATCHES]
            else:
                raise ElasticityConfigError(f"Elasticity config missing {MICRO_BATCHES}")
        else:
            self.max_acceptable_batch_size = param_dict.get(
                MAX_ACCEPTABLE_BATCH_SIZE,
                MAX_ACCEPTABLE_BATCH_SIZE_DEFAULT)
            self.micro_batches = param_dict.get(MICRO_BATCHES, MICRO_BATCHES_DEFAULT)

        # Validate that micro_batches is a list of positive integers.
        if not isinstance(self.micro_batches, list):
            raise ElasticityConfigError(
                f"Elasticity expected value of {MICRO_BATCHES} to be a "
                f"list of micro batches, instead is: {type(self.micro_batches)}, "
                f"containing: {self.micro_batches}")
        if not all(map(lambda m: isinstance(m, int), self.micro_batches)):
            raise ElasticityConfigError(
                f"Elasticity expected {MICRO_BATCHES} to only contain a list of integers, "
                f"instead contains: {self.micro_batches}")
        if not all(map(lambda m: m > 0, self.micro_batches)):
            raise ElasticityConfigError(
                f"Elasticity expected {MICRO_BATCHES} to only contain positive integers, "
                f"instead contains: {self.micro_batches}")

        self.min_gpus = param_dict.get(MIN_GPUS, MIN_GPUS_DEFAULT)
        self.max_gpus = param_dict.get(MAX_GPUS, MAX_GPUS_DEFAULT)
        if self.min_gpus < 1 or self.max_gpus < 1:
            raise ElasticityConfigError(
                "Elasticity min/max gpus must be > 0, "
                f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")
        if self.max_gpus < self.min_gpus:
            raise ElasticityConfigError(
                "Elasticity min_gpus cannot be greater than max_gpus, "
                f"given min_gpus: {self.min_gpus}, max_gpus: {self.max_gpus}")

        self.model_parallel_size = param_dict.get(MODEL_PARLLEL_SIZE,
                                                  MODEL_PARLLEL_SIZE_DEFAULT)
        if self.model_parallel_size < 1:
            raise ElasticityConfigError(
                "Model-Parallel size cannot be less than 1, "
                f"given model-parallel size: {self.model_parallel_size}")

        self.num_gpus_per_node = param_dict.get(NUM_GPUS_PER_NODE,
                                                NUM_GPUS_PER_NODE_DEFAULT)
        if self.num_gpus_per_node < 1:
            raise ElasticityConfigError(
                "Number of GPUs per node cannot be less than 1, "
                f"given number of GPUs per node: {self.num_gpus_per_node}")

        self.min_time = param_dict.get(MIN_TIME, MIN_TIME_DEFAULT)
        if self.min_time < 0:
            raise ElasticityConfigError(
                f"Elasticity min time needs to be >= 0: given {self.min_time}")

        self.version = param_dict.get(VERSION, VERSION_DEFAULT)
        self.prefer_larger_batch_size = param_dict.get(PREFER_LARGER_BATCH,
                                                       PREFER_LARGER_BATCH_DEFAULT)
        self.ignore_non_elastic_batch_info = param_dict.get(
            IGNORE_NON_ELASTIC_BATCH_INFO,
            IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT)

    def repr(self):
        return self.__dict__

    def __repr__(self):
        return json.dumps(self.__dict__, sort_keys=True, indent=4)
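

# A minimal usage sketch, assuming the constants in `.constants` map the keys
# shown in the class docstring to the names used in __init__ (e.g.
# ENABLED -> "enabled", MAX_ACCEPTABLE_BATCH_SIZE -> "max_train_batch_size",
# MICRO_BATCHES -> "micro_batch_sizes"). Because `.constants` is a relative
# import, this block runs only when the module is executed as part of its
# package (e.g. `python -m <package>.config`), not as a standalone script.
if __name__ == "__main__":
    example = {
        "enabled": True,
        "max_train_batch_size": 2000,
        "micro_batch_sizes": [2, 4, 6],
        "min_gpus": 1,
        "max_gpus": 10000,
        "min_time": 20,
        "ignore_non_elastic_batch_info": False,
        "version": 0.1,
    }
    config = ElasticityConfig(example)
    print(config)  # __repr__ dumps the parsed fields as indented JSON

    # A non-positive micro batch size should be rejected by the validation
    # in __init__ with an ElasticityConfigError.
    try:
        ElasticityConfig({**example, "micro_batch_sizes": [0, 2]})
    except ElasticityConfigError as err:
        print(f"rejected as expected: {err}")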