# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import NPUOpBuilder

try:
    import torch_npu
except ImportError:
    # torch_npu is only available in Ascend NPU environments; defer the
    # failure until the op is actually used.
    pass


class NPUFusedAdam:

    @staticmethod
    def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode,
                          bias_correction, weight_decay, *args):
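        """Apply one fused Adam/AdamW step to each flattened tensor group.

        ``tensor_lists`` is expected to hold, in order: gradients, parameters,
        first-moment buffers (m) and second-moment buffers (v). ``chunk_size``
        and ``noop_flag_buffer`` are accepted for compatibility with the
        apex-style multi_tensor_adam signature but are not used here. Each
        torch_npu kernel call is expected to perform the usual Adam update:

            m <- beta1 * m + (1 - beta1) * grad
            v <- beta2 * v + (1 - beta2) * grad**2
            param <- param - lr * m / (sqrt(v) + epsilon)

        with decoupled weight decay applied in the AdamW branch.
        """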
        # Beta powers forwarded to the torch_npu kernels as their first two
        # arguments.
        bias_correction1 = beta1**(step - 1)
        bias_correction2 = beta2**(step - 1)

        # Iterate over every flattened tensor in the param group.
        for i in range(len(tensor_lists[0])):
            grad_flat = tensor_lists[0][i]
            param_flat = tensor_lists[1][i]
            m_flat = tensor_lists[2][i]
            v_flat = tensor_lists[3][i]
            if adam_w_mode:
                # AdamW path: decoupled weight decay is applied by the kernel.
                param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam_w(
                    bias_correction1,
                    bias_correction2,
                    lr,
                    weight_decay,
                    beta1,
                    beta2,
                    epsilon,
                    grad_flat,
                    None,  # max_grad_norm
                    False,  # amsgrad
                    False,  # maximize
                    out=(param_flat.data, m_flat, v_flat))
            else:
                # Plain Adam path: this kernel takes no weight_decay argument,
                # so weight_decay is not applied here.
                param_flat.data, m_flat, v_flat = torch_npu.npu_apply_adam(
                    bias_correction1,
                    bias_correction2,
                    lr,
                    beta1,
                    beta2,
                    epsilon,
                    grad_flat,
                    False,  # use_locking
                    False,  # use_nesterov
                    out=(param_flat.data, m_flat, v_flat))


class FusedAdamBuilder(NPUOpBuilder):
    BUILD_VAR = "DS_BUILD_FUSED_ADAM"
    NAME = "fused_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        # No C++/kernel sources to compile: the op is implemented in Python
        # on top of torch_npu.
        return []

    def include_paths(self):
        return []

    def load(self, verbose=True):
        # Return the Python implementation directly instead of building and
        # loading a compiled extension module.
        return NPUFusedAdam
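

# Illustrative usage (a sketch, not part of the upstream module): how the
# builder and the optimizer shim fit together. The shapes and hyperparameters
# below are hypothetical, and actually running this requires an Ascend NPU
# environment with torch_npu installed.
#
#   import torch
#
#   fused_adam = FusedAdamBuilder().load()  # -> NPUFusedAdam (no compilation)
#   grads  = [torch.ones(1024, device='npu')]
#   params = [torch.zeros(1024, device='npu')]
#   m_bufs = [torch.zeros(1024, device='npu')]
#   v_bufs = [torch.zeros(1024, device='npu')]
#   fused_adam.multi_tensor_adam(
#       2048 * 32,                        # chunk_size (unused on NPU)
#       torch.zeros(1, device='npu'),     # noop_flag_buffer (unused on NPU)
#       [grads, params, m_bufs, v_bufs],  # tensor_lists
#       1e-3, 0.9, 0.999, 1e-8,           # lr, beta1, beta2, epsilon
#       1,                                # step
#       True,                             # adam_w_mode -> npu_apply_adam_w
#       True,                             # bias_correction (unused in body)
#       0.01)                             # weight_decay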