fused_adam.py

# Copyright (c) 2023 Habana Labs, Ltd. an Intel Company
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

try:
    # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
    # if successful this also means we're doing a local install and not JIT compile path
    from op_builder import __deepspeed__  # noqa: F401 # type: ignore
    from op_builder.builder import OpBuilder
except ImportError:
    from deepspeed.ops.op_builder.builder import OpBuilder

try:
    import torch
    import math
except ImportError as e:
    pass


class HPUFusedAdam:
    htcore = None
    is_lazy_mode = None

    @staticmethod
    def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode,
                          bias_correction, weight_decay, *args):
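        # tensor_lists is assumed to follow the apex-style FusedAdam layout
        # [grads, params, exp_avg, exp_avg_sq]; tensor_lists[1] (the params) is
        # used below to pick the dtype and device for the scalar step tensors.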
        if HPUFusedAdam.htcore is None:
            from habana_frameworks.torch import core as htcore
            from habana_frameworks.torch.utils.internal import is_lazy
            HPUFusedAdam.htcore = htcore
            HPUFusedAdam.is_lazy_mode = is_lazy()

        htcore = HPUFusedAdam.htcore
        htcore.step_closure._mark_step_if_lazy()

        step_size = lr
        if bias_correction:
            bias_correction1 = 1.0 - pow(beta1, step)
            bias_correction2 = 1.0 - pow(beta2, step)
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
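
        # At this point step_size is the standard bias-corrected Adam step size:
        #     step_size = lr * sqrt(1 - beta2**step) / (1 - beta1**step)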
        neg_step = -step_size
        neg_step_t = (torch.tensor([neg_step], dtype=torch.float,
                                   requires_grad=False).to(tensor_lists[1][0].dtype).to(tensor_lists[1][0].device,
                                                                                        non_blocking=True))

        weight_decay = weight_decay if adam_w_mode else 0

        # since lr is fed into the kernel as a tensor, perform the scalar multiplication of wd here
        # NOTE: TODO if lr is updated every step, then we need to convert it to a tensor and
        # perform weight decay unconditionally.
        modified_wd = 1.0 - weight_decay * lr
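        # modified_wd is the decoupled (AdamW-style) decay factor; the HPU kernel is
        # expected to scale the parameters by (1 - lr * weight_decay) as part of the
        # fused update.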

        if HPUFusedAdam.is_lazy_mode:
            torch.ops.hpu.optimizer_adamw(
                tensor_lists[0],
                tensor_lists[1],
                tensor_lists[2],
                tensor_lists[3],
                neg_step_t,
                beta1,
                beta2,
                epsilon,
                modified_wd,
            )
        else:
            modified_wd_t = (torch.tensor([modified_wd], dtype=torch.float, requires_grad=False).to(
                tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, non_blocking=True))
            torch.ops.hpu.optimizer_adamw(
                tensor_lists[0],
                tensor_lists[1],
                tensor_lists[2],
                tensor_lists[3],
                neg_step_t,
                beta1,
                beta2,
                epsilon,
                modified_wd_t,
                modified_wd != 1.0,
            )

        htcore.step_closure._mark_step_if_lazy()
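

# The HPU build of fused_adam has nothing to compile: sources() and include_paths()
# are empty, and load() simply hands back the Python HPUFusedAdam shim above.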
class FusedAdamBuilder(OpBuilder):
    BUILD_VAR = "DS_BUILD_FUSED_ADAM"
    NAME = "fused_adam"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        return []

    def include_paths(self):
        return []

    def load(self, verbose=True):
        return HPUFusedAdam
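

# Usage sketch (illustrative; variable names and hyper-parameter values are assumed, not
# taken from this file). A caller such as an optimizer would load the shim through the
# builder and pass one tensor list per optimizer state, matching the signature above:
#
#     fused_adam = FusedAdamBuilder().load()
#     fused_adam.multi_tensor_adam(
#         2048 * 32,                               # chunk_size (unused on HPU)
#         None,                                    # noop_flag_buffer (unused on HPU)
#         [grads, params, exp_avgs, exp_avg_sqs],  # lists of matching HPU tensors
#         lr, beta1, beta2, eps, step,
#         True,                                    # adam_w_mode
#         True,                                    # bias_correction
#         weight_decay)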