cpu_adagrad.py

'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import torch
from deepspeed.ops.op_builder import CPUAdagradBuilder
from deepspeed.utils.logging import should_log_le

class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
    optimizer_id = 0

    def __init__(self,
                 model_params,
                 lr=1e-2,
                 eps=1e-10,
                 weight_decay=0,
                 amsgrad=False,
                 fp32_optimizer_states=True):

        default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args)

        self.opt_id = DeepSpeedCPUAdagrad.optimizer_id
        DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1
        self.fp32_optimizer_states = fp32_optimizer_states
        self.ds_opt_adagrad = CPUAdagradBuilder().load()

        self.ds_opt_adagrad.create_adagrad(self.opt_id,
                                           lr,
                                           eps,
                                           weight_decay,
                                           should_log_le("info"))
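        # The loaded C++ extension keeps its own optimizer state on the native
        # side, keyed by opt_id; __del__ below calls destroy_adagrad() so that
        # state is freed when this Python object goes away.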

    def __del__(self):
        # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
        # is used multiple times in the same process (notebook or pytest worker)
        self.ds_opt_adagrad.destroy_adagrad(self.opt_id)

    def __setstate__(self, state):
        super(DeepSpeedCPUAdagrad, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None, fp16_param_groups=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.
            fp16_param_groups: FP16 GPU parameters to update. Performing the
                copy here reduces communication time. Defaults to ``None``.

        Returns:
            loss: the value returned by ``closure`` if one is provided,
                otherwise ``None``.
        """

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):

                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdagrad param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    # use full precision for the optimizer state unless
                    # fp32_optimizer_states has been turned off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # running sum of squared gradients (the Adagrad accumulator)
                    state['exp_avg_sq'] = torch.zeros_like(p.data,
                                                           dtype=state_dtype,
                                                           device='cpu')

                state['step'] += 1
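
                # Per-element Adagrad recurrence applied by the C++ kernel
                # (sketch; weight decay, if any, is folded in by the kernel):
                #   exp_avg_sq += grad * grad
                #   param      -= lr * grad / (sqrt(exp_avg_sq) + eps)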
                if p.grad.is_sparse:
                    sparse_param = p.sparse_mask(p.grad)
                    sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad)
                    self.ds_opt_adagrad.adagrad_update(self.opt_id,
                                                       state['step'],
                                                       group['lr'],
                                                       group['eps'],
                                                       group['weight_decay'],
                                                       sparse_param.values(),
                                                       p.grad.values(),
                                                       sparse_exp_avg_sq.values())
                    p[sparse_param.indices()] = sparse_param.values()
                    state['exp_avg_sq'][sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values()
                    if fp16_param_groups is not None:
                        fp16_param_groups[group_id][param_id][
                            sparse_param.indices()] = sparse_param.values()
                else:
                    if fp16_param_groups is not None:
                        # update the fp32 master copy and write the result
                        # straight into the matching fp16 parameter
                        self.ds_opt_adagrad.adagrad_update_copy(
                            self.opt_id,
                            state['step'],
                            group['lr'],
                            group['eps'],
                            group['weight_decay'],
                            p.data,
                            p.grad.data,
                            state['exp_avg_sq'],
                            fp16_param_groups[group_id][param_id].data)
                    else:
                        self.ds_opt_adagrad.adagrad_update(self.opt_id,
                                                           state['step'],
                                                           group['lr'],
                                                           group['eps'],
                                                           group['weight_decay'],
                                                           p.data,
                                                           p.grad.data,
                                                           state['exp_avg_sq'])

        return loss
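

if __name__ == '__main__':
    # Minimal standalone sketch (illustrative, not part of the upstream module):
    # exercises the optimizer directly on a toy CPU model. In a normal
    # ZeRO-Offload run the DeepSpeed engine constructs this optimizer and calls
    # step() for you via ``engine.step()``. Assumes the CPUAdagrad C++ extension
    # can be built/loaded on this machine; the Linear model, lr, and batch
    # shapes below are arbitrary.
    model = torch.nn.Linear(8, 2)  # parameters (and grads) stay on the CPU
    optimizer = DeepSpeedCPUAdagrad(model.parameters(), lr=1e-2)

    inputs, targets = torch.randn(4, 8), torch.randn(4, 2)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

    optimizer.step()       # runs the fused CPU Adagrad kernel
    optimizer.zero_grad()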