cpu_adagrad.py

'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import math
import torch
import time
from pathlib import Path
from ..op_builder import CPUAdagradBuilder
from deepspeed.utils.logging import should_log_le

class DeepSpeedCPUAdagrad(torch.optim.Optimizer):
    optimizer_id = 0

    def __init__(self,
                 model_params,
                 lr=1e-2,
                 eps=1e-10,
                 weight_decay=0,
                 amsgrad=False,
                 fp32_optimizer_states=True):
        default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
        super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args)

        # Each Python-side instance gets a unique id so that it maps to its own
        # optimizer object inside the C++ extension.
        self.opt_id = DeepSpeedCPUAdagrad.optimizer_id
        DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1
        self.fp32_optimizer_states = fp32_optimizer_states
        self.ds_opt_adagrad = CPUAdagradBuilder().load()

        self.ds_opt_adagrad.create_adagrad(self.opt_id,
                                           lr,
                                           eps,
                                           weight_decay,
                                           should_log_le("info"))

    def __del__(self):
        # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize
        # is used multiple times in the same process (notebook or pytest worker)
        self.ds_opt_adagrad.destroy_adagrad(self.opt_id)

    def __setstate__(self, state):
        super(DeepSpeedCPUAdagrad, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None, fp16_param_groups=None):
        """Update the model parameters.

        .. note::
            This method will be called internally by ZeRO-Offload. DeepSpeed
            users should still use ``engine.step()`` as shown in the
            `Getting Started
            <https://www.deepspeed.ai/getting-started/#training>`_ guide.

        Args:
            closure (callable, optional): closure to compute the loss.
                Defaults to ``None``.
            fp16_param_groups: FP16 GPU parameters to update. Performing the
                copy here reduces communication time. Defaults to ``None``.

        Returns:
            loss: the loss returned by ``closure`` if one is provided,
                otherwise ``None``.

        A minimal standalone usage sketch appears at the end of this file.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group['params']):
                if p.grad is None:
                    continue

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                    state['step'] = 0

                    # use full precision by default unless self.fp32_optimizer_states is off
                    state_dtype = torch.float if self.fp32_optimizer_states else p.dtype

                    # gradient variances
                    state['exp_avg_sq'] = torch.zeros_like(p.data,
                                                           dtype=state_dtype,
                                                           device='cpu')

                state['step'] += 1

                if p.grad.is_sparse:
                    # Sparse gradients: update only the rows touched by the gradient,
                    # then scatter the updated values back into the dense tensors.
                    sparse_param = p.sparse_mask(p.grad)
                    sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad)
                    self.ds_opt_adagrad.adagrad_update(self.opt_id,
                                                       state['step'],
                                                       group['lr'],
                                                       group['eps'],
                                                       group['weight_decay'],
                                                       sparse_param.values(),
                                                       p.grad.values(),
                                                       sparse_exp_avg_sq.values())
                    p[sparse_param.indices()] = sparse_param.values()
                    state['exp_avg_sq'][
                        sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values()
                    if fp16_param_groups is not None:
                        fp16_param_groups[group_id][param_id][
                            sparse_param.indices()] = sparse_param.values()
                else:
                    if fp16_param_groups is not None:
                        # Dense update that also copies the result into the matching
                        # FP16 parameter, so no separate copy is needed afterwards.
                        self.ds_opt_adagrad.adagrad_update_copy(
                            self.opt_id,
                            state['step'],
                            group['lr'],
                            group['eps'],
                            group['weight_decay'],
                            p.data,
                            p.grad.data,
                            state['exp_avg_sq'],
                            fp16_param_groups[group_id][param_id].data)
                    else:
                        self.ds_opt_adagrad.adagrad_update(self.opt_id,
                                                           state['step'],
                                                           group['lr'],
                                                           group['eps'],
                                                           group['weight_decay'],
                                                           p.data,
                                                           p.grad.data,
                                                           state['exp_avg_sq'])
        return loss
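

# ---------------------------------------------------------------------------
# Minimal standalone usage sketch (not part of the original module). It assumes
# the CPUAdagrad C++ op can be built on this machine and that DeepSpeed is
# installed; the toy model, data, and loss below are placeholders chosen only
# for illustration. In normal DeepSpeed training you would not call step()
# directly -- the engine drives this optimizer via ZeRO-Offload, as noted in
# the docstring above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = torch.nn.Linear(8, 4)  # toy CPU model (placeholder)
    optimizer = DeepSpeedCPUAdagrad(model.parameters(), lr=1e-2, weight_decay=1e-4)

    data = torch.randn(16, 8)
    target = torch.randn(16, 4)

    def closure():
        # recompute gradients and the loss so step() can return it
        optimizer.zero_grad()
        loss = ((model(data) - target) ** 2).mean()
        loss.backward()
        return loss

    loss = optimizer.step(closure)
    print(f'loss after one CPU Adagrad step: {loss.item():.4f}')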