# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject

#########################################
#  DeepSpeed Activation Checkpointing
#########################################
# Activation checkpointing saves memory by keeping only a select few
# activations for the backpropagation.

# Human-readable description of the expected configuration layout.
# The key names shown here must match the ACT_CHKPT_* key constants below.
ACTIVATION_CHKPT_FORMAT = '''
Activation Checkpointing should be configured as:
"session_params": {
  "activation_checkpointing": {
    "partition_activations": [true|false],
    "number_checkpoints": 100,
    "contiguous_memory_optimization": [true|false],
    "cpu_checkpointing": [true|false],
    "profile": [true|false],
    "synchronize_checkpoint_boundary": [true|false],
    }
}
'''

# Key name / default value pairs for every supported option.
ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations'
ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False

ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints'
ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None

ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization'
ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False

ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary'
ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False

ACT_CHKPT_PROFILE = 'profile'
ACT_CHKPT_PROFILE_DEFAULT = False

ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing'
ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False

# Name of the section inside the top-level parameter dictionary, and the
# section contents used when that section is absent.
ACT_CHKPT = 'activation_checkpointing'
ACT_CHKPT_DEFAULT = {
    ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT,
    ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT,
    ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION: ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT,
    ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY: ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT,
    ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT,
    ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT
}


class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject):
    """Activation-checkpointing options read from a parameter dictionary.

    The constructor looks for the ``activation_checkpointing`` section
    (``ACT_CHKPT``) in ``param_dict``. When the section is missing,
    ``ACT_CHKPT_DEFAULT`` is used in its place. Each option is then read
    with ``get_scalar_param`` and stored as an instance attribute named
    after its configuration key.
    """

    def __init__(self, param_dict):
        """Build the configuration from ``param_dict``.

        Args:
            param_dict: Mapping that may contain an
                ``activation_checkpointing`` entry holding the options
                listed in ``ACTIVATION_CHKPT_FORMAT``.
        """
        super(DeepSpeedActivationCheckpointingConfig, self).__init__()

        # Declare every attribute up front; _initialize() fills them in.
        self.partition_activations = None
        self.contiguous_memory_optimization = None
        self.cpu_checkpointing = None
        self.number_checkpoints = None
        self.synchronize_checkpoint_boundary = None
        self.profile = None

        # Fall back to the all-defaults section when none was supplied.
        act_chkpt_config_dict = param_dict.get(ACT_CHKPT, ACT_CHKPT_DEFAULT)

        self._initialize(act_chkpt_config_dict)

    def _initialize(self, act_chkpt_config_dict):
        """Populate the attributes from the activation-checkpointing section.

        Args:
            act_chkpt_config_dict: The section's contents. Each option is
                looked up with ``get_scalar_param``, passing the matching
                ``ACT_CHKPT_*_DEFAULT`` constant as the default.
        """
        self.partition_activations = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_PARTITION_ACTIVATIONS,
                                                      ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT)

        self.contiguous_memory_optimization = get_scalar_param(act_chkpt_config_dict,
                                                               ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION,
                                                               ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT)

        self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_CPU_CHECKPOINTING,
                                                  ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT)

        self.number_checkpoints = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_NUMBER_CHECKPOINTS,
                                                   ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT)

        self.profile = get_scalar_param(act_chkpt_config_dict, ACT_CHKPT_PROFILE, ACT_CHKPT_PROFILE_DEFAULT)

        self.synchronize_checkpoint_boundary = get_scalar_param(act_chkpt_config_dict,
                                                                ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY,
                                                                ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT)