# cpu_adagrad.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from deepspeed.ops.op_builder import CPUAdagradBuilder
from deepspeed.utils.logging import should_log_le


class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
    optimizer_id = 0

    def __init__(self, model_params, lr=1e-2, eps=1e-10, weight_decay=0, amsgrad=False, fp32_optimizer_states=True):
        default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args)

        # Each Python-side optimizer owns a matching C++ optimizer instance,
        # keyed by a process-wide, monotonically increasing id.
        self.opt_id = DeepSpeedCPUAdagrad.optimizer_id
        DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1
        self.fp32_optimizer_states = fp32_optimizer_states

        # Load the CPU Adagrad C++ extension and register this optimizer instance with it.
        self.ds_opt_adagrad = CPUAdagradBuilder().load()
        self.ds_opt_adagrad.create_adagrad(self.opt_id, lr, eps, weight_decay, should_log_le("info"))

    def __del__(self):
        # We need to destroy the C++ object explicitly to avoid a memory leak when
        # deepspeed.initialize is used multiple times in the same process
        # (notebook or pytest worker).
        self.ds_opt_adagrad.destroy_adagrad(self.opt_id)

    def __setstate__(self, state):
        super(DeepSpeedCPUAdagrad, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None, fp16_param_groups=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.
            fp16_param_groups: FP16 GPU parameters to update. Performing the
                copy here reduces communication time. Defaults to ``None``.

        Returns:
            loss: if ``closure`` is provided. Otherwise ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # intended device for step
        device = torch.device('cpu')
        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                assert p.device == device, f"CPUAdagrad param is on {p.device} and must be 'cpu', make " \
                    "sure you enabled 'offload_optimizer': 'cpu' in your ZeRO config."

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0

                    # Use full precision for optimizer state by default, unless
                    # self.fp32_optimizer_states is turned off.
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # Running sum of squared gradients (the Adagrad accumulator).
                    state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=state_dtype, device='cpu')

                state['step'] += 1
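
                # Conceptually, each fused kernel call below applies the standard Adagrad
                # rule element-wise (a sketch of the math; the exact arithmetic lives in
                # the C++ extension):
                #     exp_avg_sq += grad * grad
                #     param      -= lr * grad / (sqrt(exp_avg_sq) + eps)
                # with weight decay folded into ``grad`` first when ``weight_decay`` > 0.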

                if p.grad.is_sparse:
                    sparse_param = p.sparse_mask(p.grad)
                    sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad)
                    self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'],
                                                       group['weight_decay'], sparse_param.values(), p.grad.values(),
                                                       sparse_exp_avg_sq.values())
                    # Scatter the updated values back into the dense parameter and state tensors.
                    p[sparse_param.indices()] = sparse_param.values()
                    state['exp_avg_sq'][sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values()
                    if fp16_param_groups is not None:
                        fp16_param_groups[group_id][param_id][sparse_param.indices()] = sparse_param.values()
                else:
                    if fp16_param_groups is not None:
                        # Fused update that also copies the new parameter values into the
                        # corresponding FP16 GPU tensor, saving a separate transfer.
                        self.ds_opt_adagrad.adagrad_update_copy(self.opt_id, state['step'], group['lr'], group['eps'],
                                                                group['weight_decay'], p.data, p.grad.data,
                                                                state['exp_avg_sq'],
                                                                fp16_param_groups[group_id][param_id].data)
                    else:
                        self.ds_opt_adagrad.adagrad_update(self.opt_id, state['step'], group['lr'], group['eps'],
                                                           group['weight_decay'], p.data, p.grad.data,
                                                           state['exp_avg_sq'])

        return loss
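

# -----------------------------------------------------------------------------
# Illustrative usage sketch. It assumes a DeepSpeed installation where the CPU
# Adagrad op can be built and loaded. In normal training you would not construct
# this optimizer by hand: ZeRO-Offload instantiates it for you, and you call
# ``engine.step()`` instead of driving it directly as done here.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # A tiny CPU model; parameters must live on the CPU for this optimizer.
    model = torch.nn.Linear(8, 4)
    optimizer = DeepSpeedCPUAdagrad(model.parameters(), lr=1e-2, weight_decay=0.0)

    inputs = torch.randn(16, 8)
    targets = torch.randn(16, 4)

    for _ in range(3):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item():.4f}")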