cpu_adam.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. '''
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. '''
  4. import math
  5. import torch
  6. import time
  7. from pathlib import Path
  8. from ..op_builder import CPUAdamBuilder
  9. from deepspeed.utils.logging import should_log_le
  10. class DeepSpeedCPUAdam(torch.optim.Optimizer):
  11. optimizer_id = 0
  12. def __init__(self,
  13. model_params,
  14. lr=1e-3,
  15. bias_correction=True,
  16. betas=(0.9,
  17. 0.999),
  18. eps=1e-8,
  19. weight_decay=0,
  20. amsgrad=False,
  21. adamw_mode=True,
  22. fp32_optimizer_states=True):
  23. """Fast vectorized implementation of two variations of Adam optimizer on CPU:
  24. * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
  25. * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101)
  26. DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W).
  27. In order to apply this optimizer, the model requires to have its master parameter (in FP32)
  28. reside on the CPU memory.
  29. To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
  30. the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
  31. with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role to minimize
  32. the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial
  33. (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
  34. For calling step function, there are two options available: (1) update optimizer's states and (2) update
  35. optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second
  36. option can bring 30% higher throughput than the doing the copy separately using option one.
  37. .. note::
  38. We recommend using our `config
  39. <https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
  40. to allow :meth:`deepspeed.initialize` to build this optimizer
  41. for you.
  42. Arguments:
  43. model_params (iterable): iterable of parameters to optimize or dicts defining
  44. parameter groups.
  45. lr (float, optional): learning rate. (default: 1e-3)
  46. betas (Tuple[float, float], optional): coefficients used for computing
  47. running averages of gradient and its square. (default: (0.9, 0.999))
  48. eps (float, optional): term added to the denominator to improve
  49. numerical stability. (default: 1e-8)
  50. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  51. amsgrad (boolean, optional): whether to use the AMSGrad variant of this
  52. algorithm from the paper `On the Convergence of Adam and Beyond`_
  53. (default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
  54. adamw_mode: select between Adam and AdamW implementations (default: AdamW)
  55. full_precision_optimizer_states: creates momementum and variance in full precision regardless of
  56. the precision of the parameters (default: True)
  57. """
  58. default_args = dict(lr=lr,
  59. betas=betas,
  60. eps=eps,
  61. weight_decay=weight_decay,
  62. bias_correction=bias_correction,
  63. amsgrad=amsgrad)
  64. super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)
  65. self.opt_id = DeepSpeedCPUAdam.optimizer_id
  66. DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1
  67. self.adam_w_mode = adamw_mode
  68. self.fp32_optimizer_states = fp32_optimizer_states
  69. self.ds_opt_adam = CPUAdamBuilder().load()
  70. self.ds_opt_adam.create_adam(self.opt_id,
  71. lr,
  72. betas[0],
  73. betas[1],
  74. eps,
  75. weight_decay,
  76. adamw_mode,
  77. should_log_le("info"))
  78. def __del__(self):
  79. # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
  80. # is used multiple times in the same process (notebook or pytest worker)
  81. self.ds_opt_adam.destroy_adam(self.opt_id)
  82. def __setstate__(self, state):
  83. super(DeepSpeedCPUAdam, self).__setstate__(state)
  84. for group in self.param_groups:
  85. group.setdefault('amsgrad', False)
  86. @torch.no_grad()
  87. def step(self, closure=None, fp16_param_groups=None):
  88. """Update the model parameters.
  89. .. note::
  90. This method will be called internally by ZeRO-Offload. DeepSpeed
  91. users should still use ``engine.step()`` as shown in the
  92. `Getting Started
  93. <https://www.deepspeed.ai/getting-started/#training>`_ guide.
  94. Args:
  95. closure (callable, optional): closure to compute the loss.
  96. Defaults to ``None``.
  97. fp16_param_groups: FP16 GPU parameters to update. Performing the
  98. copy here reduces communication time. Defaults to ``None``.
  99. Returns:
  100. loss: if ``closure`` is provided. Otherwise ``None``.
  101. """
  102. loss = None
  103. if closure is not None:
  104. with torch.enable_grad():
  105. loss = closure()
  106. # converting the fp16 params to a group of parameter
  107. if type(fp16_param_groups) is list:
  108. if type(fp16_param_groups[0]) is not list:
  109. fp16_param_groups = [fp16_param_groups]
  110. elif fp16_param_groups is not None:
  111. fp16_param_groups = [[fp16_param_groups]]
  112. for group_id, group in enumerate(self.param_groups):
  113. for param_id, p in enumerate(group['params']):
  114. if p.grad is None:
  115. continue
  116. state = self.state[p]
  117. # State initialization
  118. if len(state) == 0:
  119. #print(f'group {group_id} param {param_id} = {p.numel()}')
  120. state['step'] = 0
  121. #use full precision by default unless self.fp32_optimizer_states is off
  122. state_dtype = torch.float if self.fp32_optimizer_states else p.dtype
  123. # gradient momentums
  124. state['exp_avg'] = torch.zeros_like(p.data,
  125. dtype=state_dtype,
  126. device='cpu')
  127. #memory_format=torch.preserve_format)
  128. # gradient variances
  129. state['exp_avg_sq'] = torch.zeros_like(p.data,
  130. dtype=state_dtype,
  131. device='cpu')
  132. #memory_format=torch.preserve_format)
  133. state['step'] += 1
  134. beta1, beta2 = group['betas']
  135. if fp16_param_groups is not None:
  136. self.ds_opt_adam.adam_update_copy(
  137. self.opt_id,
  138. state['step'],
  139. group['lr'],
  140. beta1,
  141. beta2,
  142. group['eps'],
  143. group['weight_decay'],
  144. group['bias_correction'],
  145. p.data,
  146. p.grad.data,
  147. state['exp_avg'],
  148. state['exp_avg_sq'],
  149. fp16_param_groups[group_id][param_id].data)
  150. else:
  151. self.ds_opt_adam.adam_update(self.opt_id,
  152. state['step'],
  153. group['lr'],
  154. beta1,
  155. beta2,
  156. group['eps'],
  157. group['weight_decay'],
  158. group['bias_correction'],
  159. p.data,
  160. p.grad.data,
  161. state['exp_avg'],
  162. state['exp_avg_sq'])
  163. return loss