config.py 3.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
  5. #########################################
  6. # DeepSpeed Activation Checkpointing
  7. #########################################
# Activation checkpointing saves memory by keeping only a select few
# activations for backpropagation.
  10. ACTIVATION_CHKPT_FORMAT = '''
  11. Activation Checkpointing should be configured as:
  12. "session_params": {
  13. "activation_checkpointing": {
  14. "partitioned_activations": [true|false],
  15. "number_checkpoints": 100,
  16. "contiguous_memory_optimization": [true|false],
  17. "cpu_checkpointing": [true|false],
  18. "profile": [true|false],
  19. "synchronize_checkpoint_boundary": [true|false],
  20. }
  21. }
  22. '''
  23. ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations'
  24. ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False
  25. ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints'
  26. ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None
  27. ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization'
  28. ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False
  29. ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary'
  30. ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False
  31. ACT_CHKPT_PROFILE = 'profile'
  32. ACT_CHKPT_PROFILE_DEFAULT = False
  33. ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing'
  34. ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False
  35. ACT_CHKPT = 'activation_checkpointing'
  36. ACT_CHKPT_DEFAULT = {
  37. ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT,
  38. ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT,
  39. ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION: ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT,
  40. ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY: ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT,
  41. ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT,
  42. ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT
  43. }
  44. class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject):
  45. def __init__(self, param_dict):
  46. super(DeepSpeedActivationCheckpointingConfig, self).__init__()
  47. self.partition_activations = None
  48. self.contiguous_memory_optimization = None
  49. self.cpu_checkpointing = None
  50. self.number_checkpoints = None
  51. self.synchronize_checkpoint_boundary = None
  52. self.profile = None
  53. if ACT_CHKPT in param_dict.keys():
  54. act_chkpt_config_dict = param_dict[ACT_CHKPT]
  55. else:
  56. act_chkpt_config_dict = ACT_CHKPT_DEFAULT
  57. self._initialize(act_chkpt_config_dict)
  58. def _initialize(self, act_chkpt_config_dict):
  59. self.partition_activations = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_PARTITION_ACTIVATIONS,
  60. ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT)
  61. self.contiguous_memory_optimization = get_scalar_param(act_chkpt_config_dict,
  62. ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
  63. ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT)
  64. self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_CPU_CHECKPOINTING,
  65. ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT)
  66. self.number_checkpoints = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_NUMBER_CHECKPOINTS,
  67. ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT)
  68. self.profile = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_PROFILE, ACT_CHKPT_PROFILE_DEFAULT)
  69. self.synchronize_checkpoint_boundary = get_scalar_param(act_chkpt_config_dict,
  70. ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
  71. ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT)