# cpu_adam.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from cpuinfo import get_cpu_info
  6. from deepspeed.utils import logger
  7. from deepspeed.utils.logging import should_log_le
  8. from deepspeed.ops.op_builder import CPUAdamBuilder
  9. class DeepSpeedCPUAdam(torch.optim.Optimizer):
  10. optimizer_id = 0
  11. def __init__(self,
  12. model_params,
  13. lr=1e-3,
  14. bias_correction=True,
  15. betas=(0.9, 0.999),
  16. eps=1e-8,
  17. weight_decay=0,
  18. amsgrad=False,
  19. adamw_mode=True,
  20. fp32_optimizer_states=True):
  21. """Fast vectorized implementation of two variations of Adam optimizer on CPU:
  22. * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
  23. * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101)
  24. DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W).
  25. In order to apply this optimizer, the model requires to have its master parameter (in FP32)
  26. reside on the CPU memory.
  27. To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
  28. the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
  29. with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role to minimize
  30. the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial
  31. (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
  32. For calling step function, there are two options available: (1) update optimizer's states and (2) update
  33. optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second
  34. option can bring 30% higher throughput than the doing the copy separately using option one.
  35. .. note::
  36. We recommend using our `config
  37. <https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
  38. to allow :meth:`deepspeed.initialize` to build this optimizer
  39. for you.
  40. Arguments:
  41. model_params (iterable): iterable of parameters to optimize or dicts defining
  42. parameter groups.
  43. lr (float, optional): learning rate. (default: 1e-3)
  44. betas (Tuple[float, float], optional): coefficients used for computing
  45. running averages of gradient and its square. (default: (0.9, 0.999))
  46. eps (float, optional): term added to the denominator to improve
  47. numerical stability. (default: 1e-8)
  48. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  49. amsgrad (boolean, optional): whether to use the AMSGrad variant of this
  50. algorithm from the paper `On the Convergence of Adam and Beyond`_
  51. (default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
  52. adamw_mode: select between Adam and AdamW implementations (default: AdamW)
  53. fp32_optimizer_states: creates momentum and variance in full precision regardless of
  54. the precision of the parameters (default: True)
  55. """
  56. default_args = dict(lr=lr,
  57. betas=betas,
  58. eps=eps,
  59. weight_decay=weight_decay,
  60. bias_correction=bias_correction,
  61. amsgrad=amsgrad)
  62. super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)
  63. cpu_info = get_cpu_info()
  64. self.cpu_vendor = cpu_info["vendor_id_raw"].lower() if "vendor_id_raw" in cpu_info else "unknown"
  65. if "amd" in self.cpu_vendor:
  66. for group_id, group in enumerate(self.param_groups):
  67. for param_id, p in enumerate(group['params']):
  68. if p.dtype == torch.half:
  69. logger.warning("FP16 params for CPUAdam may not work on AMD CPUs")
  70. break
  71. else:
  72. continue
  73. break
  74. self.opt_id = DeepSpeedCPUAdam.optimizer_id
  75. DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
  76. self.adam_w_mode = adamw_mode
  77. self.fp32_optimizer_states = fp32_optimizer_states
  78. self.ds_opt_adam = CPUAdamBuilder().load()
  79. self.ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode,
  80. should_log_le("info"))
  81. def __del__(self):
  82. # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
  83. # is used multiple times in the same process (notebook or pytest worker)
  84. self.ds_opt_adam.destroy_adam(self.opt_id)
  85. def __setstate__(self, state):
  86. super(DeepSpeedCPUAdam, self).__setstate__(state)
  87. for group in self.param_groups:
  88. group.setdefault('amsgrad', False)
  89. @torch.no_grad()
  90. def step(self, closure=None):
  91. """Update the model parameters.
  92. .. note::
  93. This method will be called internally by ZeRO-Offload. DeepSpeed
  94. users should still use ``engine.step()`` as shown in the
  95. `Getting Started
  96. <https://www.deepspeed.ai/getting-started/#training>`_ guide.
  97. Args:
  98. closure (callable, optional): closure to compute the loss.
  99. Defaults to ``None``.
  100. Returns:
  101. loss: if ``closure`` is provided. Otherwise ``None``.
  102. """
  103. loss = None
  104. if closure is not None:
  105. with torch.enable_grad():
  106. loss = closure()
  107. # intended device for step
  108. device = torch.device('cpu')
  109. for group_id, group in enumerate(self.param_groups):
  110. for param_id, p in enumerate(group['params']):
  111. if p.grad is None:
  112. continue
  113. assert p.device == device, f"CPUAdam param is on {p.device} and must be 'cpu', make " \
  114. "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."
  115. state = self.state[p]
  116. # State initialization
  117. if len(state) == 0:
  118. #print(f'group {group_id} param {param_id} = {p.numel()}')
  119. state['step'] = 0
  120. #use full precision by default unless self.fp32_optimizer_states is off
  121. state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
  122. # gradient momentums
  123. state['exp_avg'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
  124. #memory_format=torch.preserve_format)
  125. # gradient variances
  126. state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device=device)
  127. #memory_format=torch.preserve_format)
  128. state['step'] += 1
  129. beta1, beta2 = group['betas']
  130. self.ds_opt_adam.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
  131. group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
  132. state['exp_avg'], state['exp_avg_sq'])
  133. return loss