

DeepSpeed Configuration JSON

Batch Size Related Parameters

Note: train_batch_size must be equal to train_micro_batch_size_per_gpu * gradient_accumulation_steps * number of GPUs. For simplicity, you can choose to specify only two of the three parameters; DeepSpeed infers the third automatically.

train_batch_size: [integer]

Description Example
The effective training batch size: the number of data samples that contributes to one model update step. train_batch_size is the product of the batch size that a single GPU processes in one forward/backward pass (train_micro_batch_size_per_gpu), the number of gradient accumulation steps (gradient_accumulation_steps), and the number of GPUs. Can be omitted if both train_micro_batch_size_per_gpu and gradient_accumulation_steps are provided. 32

train_micro_batch_size_per_gpu: [integer]

Description Default
Batch size to be processed by one GPU in one step (without gradient accumulation). Can be omitted if both train_batch_size and gradient_accumulation_steps are provided. train_batch_size value

gradient_accumulation_steps: [integer]

Description Default
Number of training steps to accumulate gradients over before averaging and applying them. This feature is sometimes useful to improve scalability, since gradients are communicated less frequently between steps. It also makes it possible to train with larger effective batch sizes than fit on each GPU in a single pass. Can be omitted if both train_batch_size and train_micro_batch_size_per_gpu are provided. 1
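The batch-size note above boils down to one multiplication. The following Python sketch infers the missing value when exactly two of the three parameters are given and rejects combinations that do not multiply out. The helper name resolve_batch_params and its argument names are hypothetical; it is not DeepSpeed's own inference code, and it does not model the single-value defaults listed above.

```python
from typing import Optional, Tuple


def resolve_batch_params(
    world_size: int,
    train_batch_size: Optional[int] = None,
    micro_batch: Optional[int] = None,
    grad_accum: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Infer the one missing value; return (train_batch_size, micro_batch, grad_accum).

    Hypothetical helper, not DeepSpeed's implementation.
    """
    if train_batch_size is None and micro_batch is not None and grad_accum is not None:
        train_batch_size = micro_batch * grad_accum * world_size
    elif micro_batch is None and train_batch_size is not None and grad_accum is not None:
        micro_batch = train_batch_size // (grad_accum * world_size)
    elif grad_accum is None and train_batch_size is not None and micro_batch is not None:
        grad_accum = train_batch_size // (micro_batch * world_size)
    if train_batch_size is None or micro_batch is None or grad_accum is None:
        raise ValueError("specify at least two of the three batch-size parameters")
    # Given and inferred values must satisfy the relation exactly (no silent rounding).
    if train_batch_size != micro_batch * grad_accum * world_size:
        raise ValueError("train_batch_size must equal micro_batch * grad_accum * world_size")
    return train_batch_size, micro_batch, grad_accum


# 8 GPUs, micro-batch of 4, 2 accumulation steps -> effective batch of 64.
print(resolve_batch_params(8, micro_batch=4, grad_accum=2))  # (64, 4, 2)
```

Integer division followed by the exactness check means a value such as train_batch_size=60 on 8 GPUs with 2 accumulation steps is rejected instead of being silently rounded.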

Optimizer Parameters

optimizer: [dictionary]

Fields Value Example
type The optimizer name. DeepSpeed natively supports Adam, AdamW, OneBitAdam, Lamb, and OneBitLamb optimizers (See here for details) and will import other optimizers from torch. "Adam"
params Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for Adam). {"lr": 0.001, "eps": 1e-8}

Example of optimizer with Adam

"optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  }

The Adam optimizer also supports the following two params keys/values in addition to the standard parameters from torch.optim.Adam:

"params" key Description Default
torch_adam Use torch's implementation of Adam instead of our fused Adam implementation false
adam_w_mode Apply decoupled weight decay (also known as AdamW) instead of L2 regularization true

Another example of optimizer with 1-bit Adam specific parameters is as follows.

"optimizer": {
    "type": "OneBitAdam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7,
      "freeze_step": 400,
      "cuda_aware": false,
      "comm_backend_name": "nccl"
    }
  }

The 1-bit Adam optimizer supports the following three params keys/values in addition to the standard Adam (learn more in our tutorial):

"params" key Description Default
freeze_step Number of warmup steps before 1-bit compression is applied to the communication 100000
cuda_aware Whether the underlying MPI library supports CUDA-aware communication false
comm_backend_name Which communication backend implementation to use "nccl"
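freeze_step marks the end of the uncompressed warmup phase; after it, communication is 1-bit compressed. The sketch below is a simplified, single-process illustration of error-compensated sign compression, the general technique behind 1-bit communication. It is not DeepSpeed's implementation: there is no communication at all, and the per-tensor scale (the mean absolute value) is an assumption. one_bit_compress is a hypothetical name.

```python
from typing import List, Tuple


def one_bit_compress(values: List[float], error: List[float]) -> Tuple[List[float], List[float]]:
    """Error-compensated sign compression of one tensor (illustrative sketch).

    Returns (compressed, new_error); new_error should be passed back in on the next step.
    """
    # Error feedback: add back what the previous compression step lost.
    compensated = [v + e for v, e in zip(values, error)]
    # Assumption: one shared scale per tensor, here the mean absolute value.
    scale = sum(abs(c) for c in compensated) / len(compensated)
    # Each element is reduced to its sign (1 bit) times the shared scale.
    compressed = [scale if c >= 0 else -scale for c in compensated]
    # Remember the quantization error so it is not lost permanently.
    new_error = [c - q for c, q in zip(compensated, compressed)]
    return compressed, new_error


momentum = [1.0, -3.0]
compressed, err = one_bit_compress(momentum, [0.0, 0.0])
print(compressed, err)  # [2.0, -2.0] [-1.0, -1.0]
```

Carrying the error into the next step is what keeps the aggressive 1-bit quantization from biasing the update over time.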

Another example of optimizer with 1-bit LAMB

"optimizer": {
    "type": "OneBitLamb",
    "params": {
      "lr": 11e-3,
      "weight_decay": 0.01,
      "bias_correction": false,
      "max_coeff": 0.3,
      "min_coeff": 0.01,
      "freeze_step": 1000,
      "cuda_aware": false,
      "comm_backend_name": "nccl",
      "coeff_beta": 0.9,
      "factor_max": 4.0,
      "factor_min": 0.5,
      "factor_threshold": 0.1
    }
  }

The 1-bit LAMB optimizer supports the following params keys/values in addition to the standard LAMB (learn more in our tutorial):

"params" key Description Default
max_coeff Scaling coefficient upper bound for original LAMB algorithm and 1-bit LAMB's warmup stage 10.0
min_coeff Scaling coefficient lower bound for original LAMB algorithm and 1-bit LAMB's warmup stage 0.01
freeze_step Number of warmup steps before 1-bit compression is applied to the communication 100000
cuda_aware Whether the underlying MPI library supports CUDA-aware communication false
comm_backend_name Which communication backend implementation to use "nccl"
coeff_beta Coefficient used for computing running averages of the LAMB coefficient 0.9
factor_max Maximum value of the scaling factor applied to the frozen LAMB coefficient during the compression stage 4.0
factor_min Minimum value of the scaling factor applied to the frozen LAMB coefficient during the compression stage 0.5
factor_threshold Threshold of how much the scaling factor can fluctuate between steps 0.1

Scheduler Parameters

DeepSpeed calls the step() method of the scheduler at every training step when model_engine.step() is executed.

scheduler: [dictionary]

Fields Value Example
type The scheduler name. See here for the list of supported schedulers. "WarmupLR"
params Dictionary of parameters to instantiate scheduler. The parameter names should match scheduler constructor signature. {"warmup_min_lr": 0, "warmup_max_lr": 0.001}

Example of scheduler

 "scheduler": {
      "type": "WarmupLR",
      "params": {
          "warmup_min_lr": 0,
          "warmup_max_lr": 0.001,
          "warmup_num_steps": 1000
      }
  }
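With the example above, the learning rate climbs from warmup_min_lr to warmup_max_lr over warmup_num_steps steps and then, assuming WarmupLR holds it constant afterwards, stays at warmup_max_lr. The shape of the ramp is an assumption in this sketch: it offers a linear and a logarithmic curve, and which curve DeepSpeed's WarmupLR actually uses should be checked in the scheduler documentation. The function warmup_lr and its shape argument are hypothetical.

```python
import math


def warmup_lr(step: int, warmup_min_lr: float, warmup_max_lr: float,
              warmup_num_steps: int, shape: str = "linear") -> float:
    """Learning rate at a given step of a warmup schedule (illustrative sketch)."""
    if step >= warmup_num_steps or warmup_num_steps <= 1:
        return warmup_max_lr  # warmup finished: hold the maximum learning rate
    if shape == "linear":
        frac = step / warmup_num_steps
    elif shape == "log":
        frac = math.log(step + 1) / math.log(warmup_num_steps)
    else:
        raise ValueError("shape must be 'linear' or 'log'")
    return warmup_min_lr + frac * (warmup_max_lr - warmup_min_lr)


# Values from the example config above.
for s in (0, 500, 1000, 5000):
    print(s, warmup_lr(s, 0.0, 0.001, 1000))
```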

Communication options

fp32_allreduce: [boolean]

Description Default
Perform the gradient-averaging allreduce with 32-bit values false

prescale_gradients: [boolean]

Description Default
Scale gradients before doing allreduce false

gradient_predivide_factor: [float]

Description Default
Divide gradients by the specified factor before gradient averaging. This can sometimes help with FP16 stability when scaling to large numbers of GPUs. 1.0
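The arithmetic behind gradient_predivide_factor can be simulated without any GPUs. This sketch assumes each rank divides its gradient by the factor before the allreduce (a sum), and that the remaining division by world_size / gradient_predivide_factor happens afterwards, so the final result is still the plain average. average_with_predivide is a hypothetical name, and plain Python floats stand in for FP16 tensors, so the overflow benefit itself is not visible here.

```python
from typing import List


def average_with_predivide(per_rank_grads: List[float], predivide_factor: float = 1.0) -> float:
    """Simulate gradient averaging with a predivide factor (plain Python floats, no FP16)."""
    world_size = len(per_rank_grads)
    # Each rank scales its local gradient down before communication, so the
    # running sum inside the allreduce stays smaller (less FP16 overflow risk).
    prescaled = [g / predivide_factor for g in per_rank_grads]
    # Stand-in for the allreduce: sum the contributions of all ranks.
    reduced = sum(prescaled)
    # Assumption: the remaining division by world_size / predivide_factor happens afterwards.
    return reduced / (world_size / predivide_factor)


print(average_with_predivide([2.0, 4.0, 6.0, 8.0], predivide_factor=2.0))  # 5.0
```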

sparse_gradients: [boolean]

Description Default
Enable sparse compression of torch.nn.Embedding gradients. false

FP16 training options

Note: this mode cannot be combined with the amp mode described below.

fp16: [dictionary]

Description Default
Configuration for using mixed precision/FP16 training that leverages NVIDIA's Apex package. An example, including the available dictionary keys, is illustrated below. NOTE: this does not use Apex's AMP mode, which allows more flexibility in mixed precision training modes; this mode is similar to AMP's O2 mode. Please see AMP support below if you want to use more complex mixed precision modes. If you want to use ZeRO (currently), you must use this mode. None
"fp16": {
    "enabled": true,
    "loss_scale": 0,
    "initial_scale_power": 32,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
}

fp16:enabled: [boolean]

Description Default
enabled is an fp16 parameter indicating whether or not FP16 training is enabled. false

fp16:loss_scale: [float]

Description Default
loss_scale is an fp16 parameter representing the loss scaling value for FP16 training. The default value of 0.0 results in dynamic loss scaling; any other value is used as a static, fixed loss scale. 0.0

fp16:initial_scale_power: [integer]

Description Default
initial_scale_power is an fp16 parameter representing the power of the initial dynamic loss scale value. The actual loss scale is computed as $2^{\text{initial\_scale\_power}}$. 32

fp16:loss_scale_window: [integer]

Description Default
loss_scale_window is an fp16 parameter representing the window over which to raise/lower the dynamic loss scale value: the scale is raised after this many consecutive steps without overflow. 1000

fp16:hysteresis: [integer]

Description Default
hysteresis is an fp16 parameter representing the delay shift in dynamic loss scaling, i.e., how many overflows are tolerated before the loss scale is actually lowered. 2

fp16:min_loss_scale: [integer]

Description Default
min_loss_scale is an fp16 parameter representing the minimum dynamic loss scale value. 1
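The fp16 keys above interact as a small state machine. The class below is a simplified sketch of dynamic loss scaling under these assumptions: an overflow skips the step and, once the hysteresis budget is used up, halves the scale (never below min_loss_scale); loss_scale_window consecutive overflow-free steps double the scale and reset the hysteresis budget. SimpleDynamicLossScaler is hypothetical and is not DeepSpeed's loss scaler, whose bookkeeping may differ in detail.

```python
class SimpleDynamicLossScaler:
    """Simplified dynamic loss scaling using the fp16 keys above (illustrative sketch)."""

    def __init__(self, initial_scale_power: int = 32, loss_scale_window: int = 1000,
                 hysteresis: int = 2, min_loss_scale: float = 1) -> None:
        self.scale = 2.0 ** initial_scale_power
        self.loss_scale_window = loss_scale_window
        self.hysteresis = hysteresis
        self.min_loss_scale = min_loss_scale
        self.remaining_hysteresis = hysteresis
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow: bool) -> bool:
        """Record one step's overflow status; return True if the optimizer step should run."""
        if overflow:
            self.good_steps = 0
            if self.remaining_hysteresis > 1:
                # Tolerate this overflow: skip the step but keep the current scale.
                self.remaining_hysteresis -= 1
            else:
                # Halve the scale, but never go below min_loss_scale.
                self.scale = max(self.scale / 2, self.min_loss_scale)
            return False
        self.good_steps += 1
        if self.good_steps % self.loss_scale_window == 0:
            # A full window without overflow: try a larger scale again.
            self.scale *= 2
            self.remaining_hysteresis = self.hysteresis
        return True


scaler = SimpleDynamicLossScaler(initial_scale_power=4, loss_scale_window=2, hysteresis=2)
for overflow in (True, True, False, False):
    applied = scaler.update(overflow)
    print(overflow, applied, scaler.scale)
```

With initial_scale_power=4 the scale starts at 16.0: the first overflow is absorbed by hysteresis, the second halves the scale to 8.0, and after two clean steps the scale returns to 16.0.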

Automatic mixed precision (AMP) training options

Note: this mode cannot be combined with the fp16 mode described above. In addition, this mode is not currently compatible with ZeRO.

amp: [dictionary]

Description Default
Configuration for using automatic mixed precision (AMP) training that leverages NVIDIA's Apex AMP package. An example, including the available dictionary keys, is illustrated below. This mode is not compatible with the fp16 mode above or with ZeRO. Any parameters outside of "enabled" will be passed to AMP's initialize call; see the apex.amp.initialize documentation for the API and descriptions. None
"amp": {
    "enabled": true,
    ...
    "opt_level": "O1",
    ...
}

amp:enabled: [boolean]

Description Default
enabled is an amp parameter indicating whether or not AMP training is enabled. false

amp params: [various]

Description Default
Any parameters outside of "enabled" will be passed to AMP's initialize call; see the apex.amp.initialize documentation for the API and descriptions. None

Gradient Clipping

gradient_clipping: [float]

Description Default
Enable gradient clipping, using the given value as the maximum gradient norm 1.0
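Assuming the gradient_clipping value is the maximum global L2 norm of all gradients (the semantics of PyTorch's torch.nn.utils.clip_grad_norm_), clipping works as in the sketch below. clip_by_global_norm is a hypothetical helper that operates on nested lists instead of tensors.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float,
                        eps: float = 1e-6) -> List[List[float]]:
    """Rescale all gradients together so their combined L2 norm is at most ~max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    if total_norm <= max_norm:
        return [list(tensor) for tensor in grads]  # already within the limit
    # One shared coefficient preserves the direction of the full gradient vector.
    coef = max_norm / (total_norm + eps)
    return [[g * coef for g in tensor] for tensor in grads]


# Two "parameters" with gradients [3.0] and [4.0]: global norm 5.0, clipped to about 1.0.
print(clip_by_global_norm([[3.0], [4.0]], max_norm=1.0))
```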

ZeRO Optimizations for FP16 Training

Enabling and configuring ZeRO memory optimizations

  "zero_optimization": {
    "stage": [0|1|2|3],
    "allgather_partitions": [true|false],
    "allgather_bucket_size": 5e8,
    "overlap_comm": false,
    "reduce_scatter": [true|false],
    "reduce_bucket_size": 5e8,
    "contiguous_gradients" : [true|false],
    "offload_param": {
      ...
    },
    "offload_optimizer": {
      ...
    },
    "stage3_max_live_parameters" : 1e9,
    "stage3_max_reuse_distance" : 1e9,
    "stage3_prefetch_bucket_size" : 5e8,
    "stage3_param_persistence_threshold" : 1e6,
    "sub_group_size" : 1e12,
    "elastic_checkpoint" : [true|false],
    "stage3_gather_fp16_weights_on_model_save": [true|false],
    "ignore_unused_parameters": [true|false],
    "round_robin_gradients": [true|false]
  }

zero_optimization: [dictionary]

Description Default
Enable ZeRO memory optimization wrapper for FP16 Training. Currently compatible only with Adam optimizer. false

stage: [integer]

Description Default
Chooses the ZeRO stage. Stages 0, 1, 2, and 3 refer to disabled, optimizer state partitioning, optimizer+gradient state partitioning, and optimizer+gradient+parameter partitioning, respectively. 0
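A back-of-the-envelope way to compare stages is the per-GPU memory accounting from the ZeRO paper: with mixed-precision Adam, each parameter costs 2 bytes (FP16 weights) + 2 bytes (FP16 gradients) + K = 12 bytes (FP32 master weights, momentum, and variance), and each stage divides more of these terms by the number of GPUs. The sketch below covers only these model states; activations, communication buffers, and fragmentation are ignored, and zero_model_state_bytes is a hypothetical name.

```python
def zero_model_state_bytes(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU bytes of model states for mixed-precision Adam (ZeRO paper accounting).

    Ignores activations, temporary buffers, and fragmentation.
    """
    params = 2 * num_params  # FP16 parameters (2 bytes each)
    grads = 2 * num_params   # FP16 gradients (2 bytes each)
    optim = k * num_params   # FP32 master weights, momentum, variance: 4 + 4 + 4 bytes
    if stage == 0:
        return params + grads + optim
    if stage == 1:  # partition optimizer states
        return params + grads + optim / num_gpus
    if stage == 2:  # partition optimizer states and gradients
        return params + (grads + optim) / num_gpus
    if stage == 3:  # partition optimizer states, gradients, and parameters
        return (params + grads + optim) / num_gpus
    raise ValueError("stage must be 0, 1, 2, or 3")


# 7.5B-parameter model on 64 GPUs.
for stage in range(4):
    print(stage, round(zero_model_state_bytes(7.5e9, 64, stage) / 1e9, 1), "GB")
```

For a 7.5B-parameter model on 64 GPUs this prints roughly 120, 31.4, 16.6, and 1.9 GB for stages 0 to 3.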

allgather_partitions: [boolean]

Description Default
Chooses between the allgather collective and a series of broadcast collectives to gather updated parameters from all the GPUs at the end of each step true

allgather_bucket_size: [integer]

Description Default
Number of elements allgathered at a time. Limits the memory required for the allgather for large model sizes 5e8

overlap_comm: [boolean]

Description Default
Attempts to overlap the reduction of the gradients with backward computation false

reduce_scatter: [boolean]

Description Default
Uses reduce or reduce scatter instead of allreduce to average gradients true

reduce_bucket_size: [integer]

Description Default
Number of elements reduced/allreduced at a time. Limits the memory required for the reduction/allreduce for large model sizes 5e8
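allgather_bucket_size and reduce_bucket_size both cap how many elements go into one collective. A simplified greedy model of that bucketing is sketched below: consecutive gradient tensors are packed into a bucket until adding the next one would exceed the limit, and each bucket stands for one collective call. make_buckets is hypothetical, and DeepSpeed's real buckets are filled incrementally during the backward pass. As a rough sizing aid, assuming FP16 (2 bytes per element), the default of 5e8 elements corresponds to about 1 GB per bucket buffer.

```python
from typing import List


def make_buckets(numels: List[int], bucket_size: int) -> List[List[int]]:
    """Greedily group tensor indices so each bucket holds at most bucket_size elements.

    A single tensor larger than bucket_size gets a bucket of its own.
    """
    buckets: List[List[int]] = []
    current: List[int] = []
    current_numel = 0
    for idx, numel in enumerate(numels):
        # Close the current bucket if this tensor would overflow it.
        if current and current_numel + numel > bucket_size:
            buckets.append(current)
            current, current_numel = [], 0
        current.append(idx)
        current_numel += numel
    if current:
        buckets.append(current)  # flush the last, partially filled bucket
    return buckets


# Five gradient tensors with these element counts, bucket limit of 7 elements.
print(make_buckets([3, 4, 2, 6, 1], bucket_size=7))  # [[0, 1], [2], [3, 4]]
```

A smaller bucket size lowers peak buffer memory at the cost of more, smaller collectives.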

contiguous_gradients: [boolean]

Description Default
Copies the gradients to a contiguous buffer as they are produced. Avoids memory fragmentation during backward pass. True

grad_hooks: [boolean]

Description Default
For use with ZeRO stage 1, enable backward hooks to reduce gradients during the backward pass or wait until the end of the backward pass. True

round_robin_gradients: [boolean]

Description Default
Stage 2 optimization for CPU offloading that parallelizes gradient copying to CPU memory among ranks by fine-grained gradient partitioning. Performance benefit grows with gradient accumulation steps (more copying between optimizer steps) or GPU count (increased parallelism). False

offload_param: [dictionary]

Description Default
Enable offloading of model parameters to CPU or NVMe. This frees up GPU memory for larger models or batch sizes. Valid only with stage 3. See here for more details. False

offload_optimizer: [dictionary]

Description Default
Enable offloading of optimizer state to CPU or NVMe, and optimizer computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 2 and 3. See here for more details. False

stage3_max_live_parameters: [integer]

Description Default
The maximum number of parameters resident per GPU before releasing. Smaller values use less memory, but perform more communication. 1e9

stage3_max_reuse_distance: [integer]

Description Default
Do not release a parameter if it will be reused within this threshold of parameters. Smaller values use less memory, but perform more communication. 1e9

stage3_prefetch_bucket_size: [integer]

Description Default
The size of the fixed buffer for prefetching parameters. Smaller values use less memory, but can increase stalls due to communication. 5e8

stage3_param_persistence_threshold: [integer]

Description Default
Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). 1e6

stage3_gather_fp16_weights_on_model_save: [boolean]

Description Default
Consolidate the weights before saving the model by save_fp16_model(). Since the weights are partitioned across GPUs, they aren't part of state_dict, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. False

cpu_offload: [boolean]

Deprecated: cpu_offload is deprecated and will be removed in the future; please use offload_optimizer instead.

Description Default
Enable offloading of optimizer memory and computation to CPU. This frees up GPU memory for larger models or batch sizes. Valid only with stage 2. False

Parameter offloading

Enabling and configuring ZeRO optimization of parameter offloading to CPU/NVMe. Available only with ZeRO stage 3.

  "offload_param": {
    "device": "[none|cpu|nvme]",
    "nvme_path": "/local_nvme",
    "pin_memory": [true|false],
    "buffer_count": 5,
    "buffer_size": 1e8,
    "max_in_cpu": 1e9
  }

device: [string]

Description Default
Device memory to offload model parameters. Supported options are cpu and nvme. cpu

nvme_path: [string]

Description Default
Filesystem path for NVMe device for parameter offloading. /local_nvme

pin_memory: [boolean]

Description Default
Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. false

buffer_count: [integer]

Description Default
Number of buffers in buffer pool for parameter offloading to NVMe. 5

buffer_size: [integer]

Description Default
Size of buffers in buffer pool for parameter offloading to NVMe. 1e8

max_in_cpu: [integer]

Description Default
Number of parameter elements to maintain in CPU memory when offloading to NVMe is enabled. 1e9

Optimizer offloading

Enabling and configuring ZeRO optimization of offloading optimizer computation to CPU and state to CPU/NVMe. CPU offloading is available with ZeRO stage 2 or 3. NVMe offloading is available only with ZeRO stage 3.

  "offload_optimizer": {
    "device": "[none|cpu|nvme]",
    "nvme_path": "/local_nvme",
    "pin_memory": [true|false],
    "buffer_count": 4,
    "fast_init": false
  }

device: [string]

Description Default
Device memory to offload optimizer state. Supported options are cpu and nvme. Optimizer computation is offloaded to CPU regardless of the device option. cpu

nvme_path: [string]

Description Default
Filesystem path for NVMe device for optimizer state offloading. /local_nvme

pin_memory: [boolean]

Description Default
Offload to page-locked CPU memory. This could boost throughput at the cost of extra memory overhead. false

buffer_count: [integer]

Description Default
Number of buffers in buffer pool for optimizer state offloading to NVMe. This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance). 4

fast_init: [boolean]

Description Default
Enable fast optimizer initialization when offloading to NVMe. false
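The stage requirements stated for parameter and optimizer offloading are easy to get wrong when editing a config by hand. The hypothetical check_offload helper below encodes them, plus the buffer_count guideline for Adam, as a pre-flight check on a zero_optimization dictionary. It mirrors only the rules written in this section and is not DeepSpeed's own validation; an offload dictionary without a device key is treated as "cpu", matching the documented default.

```python
from typing import Any, Dict, List


def check_offload(zero_cfg: Dict[str, Any]) -> List[str]:
    """Return problems found in the offload settings of a zero_optimization dict."""
    problems: List[str] = []
    stage = zero_cfg.get("stage", 0)
    param = zero_cfg.get("offload_param")
    optim = zero_cfg.get("offload_optimizer")
    # Absent section -> no offload; present without "device" -> documented default cpu.
    param_dev = "none" if param is None else param.get("device", "cpu")
    optim_dev = "none" if optim is None else optim.get("device", "cpu")
    if param_dev != "none" and stage != 3:
        problems.append("offload_param requires stage 3")
    if optim_dev == "cpu" and stage not in (2, 3):
        problems.append("CPU optimizer offload requires stage 2 or 3")
    if optim_dev == "nvme":
        if stage != 3:
            problems.append("NVMe optimizer offload requires stage 3")
        # Adam keeps 4 states per parameter, so fewer buffers may not suffice.
        if optim.get("buffer_count", 4) < 4:
            problems.append("buffer_count < 4 may be too few for Adam's 4 states")
    return problems


print(check_offload({"stage": 2, "offload_optimizer": {"device": "nvme"}}))
# ['NVMe optimizer offload requires stage 3']
```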

Asynchronous I/O

Configuring the asynchronous I/O module for offloading parameter and optimizer states to persistent (NVMe) storage. This module uses Linux native asynchronous I/O (libaio).

  "aio": {
    "block_size": 1048576,
    "queue_depth": 8,
    "thread_count": 1,
    "single_submit": false,
    "overlap_events": true
  }

block_size: [integer]

Description Default
I/O block size in bytes. 1048576

queue_depth: [integer]

Description Default
I/O queue depth. 8

thread_count: [integer]

Description Default
Intra-request parallelism for each read/write submitted by a user thread. 1

single_submit: [boolean]

Description Default
Submit requests to storage device as multiple individual requests as opposed to one block of requests. false

overlap_events: [boolean]

Description Default
Submit requests to storage device in an overlapped fashion without waiting for completion of earlier requests. true

ignore_unused_parameters: [boolean]

Description Default
Unused parameters in modules may be unexpected in static networks, but could be normal in dynamic networks. This zero_optimization option controls whether or not training should terminate with an error message when unused parameters are detected. It is set to true by default, which means unused parameters are ignored and training continues. Currently used only in stage 2. True

Logging

steps_per_print: [integer]

Description Default
Print progress report every N training steps. The report includes the number of training steps, number of skipped optimizer updates (likely due to overflows in mixed-precision training), current learning rate, and current momentum. 10

wall_clock_breakdown: [boolean]

Description Default
Enable timing of the latency of forward/backward/update training phases false

dump_state: [boolean]

Description Default
Print out state information of DeepSpeed object after initialization false

Flops Profiler

{
  "flops_profiler": {
    "enabled": false,
    "profile_step": 1,
    "module_depth": -1,
    "top_modules": 1,
    "detailed": true,
    "output_file": null
  }
}

enabled: [boolean]

Description Default
Enables the flops profiler. This also enables wall_clock_breakdown false

profile_step: [integer]

Description Default
The global training step at which to profile. Note that warm-up steps are needed for accurate time measurement. 1

module_depth: [integer]

Description Default
The depth of the model at which to print the aggregated module information. When set to -1, it prints information from the top module to the innermost modules (the maximum depth). -1

top_modules: [integer]

Description Default
Limits the aggregated profile output to the number of top modules specified. 1

detailed: [boolean]

Description Default
Whether to print the detailed model profile. true

output_file: [string]

Description Default
Path to the output file. If null, the profiler prints to stdout. null

Activation Checkpointing

  "activation_checkpointing": {
    "partition_activations": false,
    "cpu_checkpointing": false,
    "contiguous_memory_optimization": false,
    "number_checkpoints": null,
    "synchronize_checkpoint_boundary": false,
    "profile": false
    }

partition_activations: [boolean]

Description Default
Enables partitioning of activations when used with model parallelism false

cpu_checkpointing: [boolean]

Description Default
Offloads partitioned activations to CPU if partition_activations is enabled false

contiguous_memory_optimization: [boolean]

Description Default
Copies partitioned activations so that they are contiguous in memory false

number_checkpoints: [integer]

Description Default
Total number of activation checkpoints, used to allocate the memory buffer for contiguous_memory_optimization None

synchronize_checkpoint_boundary: [boolean]

Description Default
Inserts torch.cuda.synchronize() at each checkpoint boundary. false

profile: [boolean]

Description Default
Logs the forward and backward time for each checkpoint function false

Sparse Attention

sparse_attention: [dictionary]

Fields Value Example
mode A string determining the sparsity structure type. DeepSpeed currently supports "dense", "fixed", "bigbird", "bslongformer", and "variable". "fixed"
block An integer determining the block size. The current implementation of sparse self-attention is based on blocked sparse matrices, and this parameter defines the size of each block (block x block). 16
different_layout_per_head A boolean determining if each head should be assigned a different sparsity layout; this will be satisfied based on availability. false
num_local_blocks An integer determining the number of blocks in each local attention window; only used in "fixed" mode. 4
num_global_blocks An integer determining how many consecutive blocks in a local window are used as the representative of the window for global attention; used in "fixed" and "bigbird" modes. 1
attention A string determining the attention type. Attention can be "unidirectional", as in autoregressive models, where tokens attend only to tokens that appear before them in the context, so the upper triangular part of the attention matrix is empty. Or it can be "bidirectional", as in BERT, where tokens can attend to any other tokens before or after them, so the upper triangular part of the attention matrix is the mirror of the lower triangular part. Used in "fixed" and "variable" modes. "bidirectional"
horizontal_global_attention A boolean determining whether blocks that are global representatives of a local window also attend to all other blocks. This is valid only if the attention type is "bidirectional". In terms of the attention matrix, this means global attention includes not only the vertical blocks but also the horizontal blocks; used in "fixed" and "variable" modes. false
num_different_global_patterns An integer determining the number of different global attention layouts. Global attention is determined by which block(s) represent each local window, and because there are multiple heads, each head can use a different global representative; used only in "fixed" mode. 4
num_random_blocks An integer determining the number of random blocks in each block row; used in "variable" and "bigbird" modes. 0
local_window_blocks A list of integers determining the number of blocks in each local attention window. The first number is the number of blocks in the first local window, the second number the number in the second window, and so on; the last number determines the number of blocks in all remaining local windows. Only used in "variable" mode. [4]
global_block_indices A list of integers determining which blocks are considered global attention blocks. The indexed blocks are attended to by all other token blocks, and they in turn attend to all other token blocks. Note that if the global_block_end_indices parameter is set, this parameter is used as the starting index of each global window; used in "variable" and "bslongformer" modes. [0]
global_block_end_indices A list of integers determining the end indices of global window blocks. By default this is not used. If it is set, it must have the same size as the global_block_indices parameter; combining the two, for each index i, blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are considered global attention blocks; used in "variable" and "bslongformer" modes. None
num_sliding_window_blocks An integer determining the number of blocks in sliding local attention window; used in "bigbird" and "bslongformer" modes. 3

Example of sparse_attention

  "sparse_attention": {
    "mode": "fixed",
    "block": 16,
    "different_layout_per_head": true,
    "num_local_blocks": 4,
    "num_global_blocks": 1,
    "attention": "bidirectional",
    "horizontal_global_attention": false,
    "num_different_global_patterns": 4,
    "num_random_blocks": 0,
    "local_window_blocks": [4],
    "global_block_indices": [0],
    "global_block_end_indices": null,
    "num_sliding_window_blocks": 3
  }
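To make the "fixed" mode parameters concrete, the sketch below builds a block-level layout from num_local_blocks, num_global_blocks, and attention; for example, a sequence length of 128 with block 16 gives 8 blocks per row. It assumes the last num_global_blocks blocks of each local window act as that window's representatives (vertical global attention only) and ignores different_layout_per_head, num_different_global_patterns, and horizontal_global_attention, so it approximates rather than reproduces DeepSpeed's layout code. fixed_block_layout is a hypothetical name.

```python
from typing import List


def fixed_block_layout(num_blocks: int, num_local_blocks: int, num_global_blocks: int,
                       attention: str = "bidirectional") -> List[List[int]]:
    """Return a num_blocks x num_blocks 0/1 mask; layout[r][c] == 1 means block row r attends to block column c."""
    if attention not in ("bidirectional", "unidirectional"):
        raise ValueError("attention must be 'bidirectional' or 'unidirectional'")
    causal = attention == "unidirectional"
    layout = [[0] * num_blocks for _ in range(num_blocks)]
    for start in range(0, num_blocks, num_local_blocks):
        end = min(start + num_local_blocks, num_blocks)
        # Local attention: every block attends within its own window.
        for r in range(start, end):
            for c in range(start, end):
                if not causal or c <= r:
                    layout[r][c] = 1
        # Global attention (vertical only): the last num_global_blocks blocks of the
        # window act as its representatives and are attended to by all rows.
        for c in range(max(start, end - num_global_blocks), end):
            for r in range(num_blocks):
                if not causal or c <= r:
                    layout[r][c] = 1
    return layout


layout = fixed_block_layout(128 // 16, num_local_blocks=4, num_global_blocks=1)
for row in layout:
    print(row)
```

In the unidirectional case the causal check zeroes everything above the diagonal, matching the description of an empty upper triangle.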

Curriculum Learning

  "curriculum_learning": {
    "enabled": true,
    "curriculum_type": "seqlen",
    "min_difficulty": 8,
    "max_difficulty": 1024,
    "schedule_type": "fixed_linear",
    "schedule_config": {
      "total_curriculum_step": 40000,
      "difficulty_step": 8
    }
  }

enabled: [boolean]

Description Default
Set to true to enable curriculum learning false

curriculum_type: [string]

Description Default
Type of curriculum difficulty metric. Currently supports seqlen. N/A

min_difficulty: [integer]

Description Default
The starting difficulty level N/A

max_difficulty: [integer]

Description Default
The ending difficulty level N/A

schedule_type: [string]

Description Default
Type of curriculum schedule. Currently supports fixed_linear, fixed_root, and fixed_discrete. N/A

total_curriculum_step: [integer]

Description Default
Total number of steps for curriculum learning. A schedule_config field used by the fixed_linear and fixed_root schedule types. N/A

difficulty_step: [integer]

Description Default
At any time, the curriculum learning difficulty must be a multiple of this difficulty_step. Set this to a multiple of 8 (for FP16 data) or 16 (for INT8 data) to enable NVIDIA Tensor Core acceleration. A schedule_config field used by the fixed_linear and fixed_root schedule types. N/A

root_degree: [integer]

Description Default
Root degree of the curriculum schedule function. A schedule_config field used by the fixed_root schedule type. N/A

difficulty: [list of integer]

Description Default
List of difficulty levels to be used during the schedule. A schedule_config field used by the fixed_discrete schedule type. N/A

max_step: [list of integer]

Description Default
List of steps at which to change the difficulty level. A schedule_config field used by the fixed_discrete schedule type. N/A
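The three schedule types can be sketched as one function that maps a training step to a difficulty. Assumptions: fixed_linear interpolates between min_difficulty and max_difficulty over total_curriculum_step steps; fixed_root applies the 1/root_degree power to the same progress ratio; both round down to a multiple of difficulty_step; and for fixed_discrete, difficulty[i] applies while step <= max_step[i], with the last entry of difficulty used afterwards (so max_step has one fewer entry than difficulty). curriculum_difficulty is hypothetical, and DeepSpeed's scheduler may differ in rounding and boundary details.

```python
from typing import List, Optional


def curriculum_difficulty(step: int, schedule_type: str, min_difficulty: int,
                          max_difficulty: int, total_curriculum_step: Optional[int] = None,
                          difficulty_step: int = 1, root_degree: int = 1,
                          difficulty: Optional[List[int]] = None,
                          max_step: Optional[List[int]] = None) -> int:
    """Difficulty (e.g. sequence length) to use at a given step (illustrative sketch)."""
    if schedule_type == "fixed_discrete":
        if difficulty is None or max_step is None:
            raise ValueError("fixed_discrete needs difficulty and max_step")
        for level, last_step in zip(difficulty, max_step):
            if step <= last_step:
                return level
        return difficulty[-1]  # past the last boundary: final difficulty
    if schedule_type not in ("fixed_linear", "fixed_root"):
        raise ValueError("unknown schedule_type: " + schedule_type)
    if not total_curriculum_step:
        raise ValueError("total_curriculum_step must be a positive integer")
    # Progress through the curriculum in [0, 1].
    ratio = min(step / total_curriculum_step, 1.0)
    if schedule_type == "fixed_root":
        ratio = ratio ** (1.0 / root_degree)  # faster early growth than linear
    value = int(min_difficulty + ratio * (max_difficulty - min_difficulty))
    # Round down to a multiple of difficulty_step (e.g. 8 for FP16 Tensor Cores).
    value -= value % difficulty_step
    return max(min(value, max_difficulty), min_difficulty)


# Halfway through the example config above (seqlen 8 -> 1024 over 40000 steps).
print(curriculum_difficulty(20000, "fixed_linear", 8, 1024,
                            total_curriculum_step=40000, difficulty_step=8))  # 512
```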