# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import NPUOpBuilder

try:
    import torch_npu
except ImportError:
    # torch_npu is optional at import time. If it is missing, the name stays
    # unbound and NPUFusedAdam.multi_tensor_adam raises NameError when called.
    pass


class NPUFusedAdam:
    """Adam / AdamW update implemented on top of torch_npu kernels."""

    @staticmethod
    def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode,
                          bias_correction, weight_decay, *args):
        """Run one Adam or AdamW step over every tensor in ``tensor_lists``.

        Args:
            chunk_size: Accepted but not used by this implementation.
            noop_flag_buffer: Accepted but not used by this implementation.
            tensor_lists: Indexable holding four parallel sequences, read as
                ``[grads, params, m, v]``. The loop length is taken from the
                first sequence.
            lr: Learning rate forwarded to the kernel.
            beta1: First beta coefficient forwarded to the kernel.
            beta2: Second beta coefficient forwarded to the kernel.
            epsilon: Epsilon forwarded to the kernel.
            step: Current step; only used to compute ``beta ** (step - 1)``.
            adam_w_mode: Truthy selects ``npu_apply_adam_w``, falsy selects
                ``npu_apply_adam``.
            bias_correction: Accepted but not used by this implementation.
            weight_decay: Forwarded to the kernel in AdamW mode only.
            *args: Ignored.

        Returns:
            None. Each parameter's ``.data`` is reassigned to the first
            element of the kernel's result.
        """
        # NOTE(review): the kernels receive beta ** (step - 1), not beta ** step;
        # presumably this matches the torch_npu kernel contract - confirm.
        beta1_power = beta1**(step - 1)
        beta2_power = beta2**(step - 1)

        # iteration group['params']
        for idx, grad in enumerate(tensor_lists[0]):
            param = tensor_lists[1][idx]
            exp_avg = tensor_lists[2][idx]
            exp_avg_sq = tensor_lists[3][idx]

            if adam_w_mode:
                updated = torch_npu.npu_apply_adam_w(
                    beta1_power,
                    beta2_power,
                    lr,
                    weight_decay,
                    beta1,
                    beta2,
                    epsilon,
                    grad,
                    None,  # max_grad_norm
                    False,  # amsgrad
                    False,  # maximize
                    out=(param.data, exp_avg, exp_avg_sq))
            else:
                updated = torch_npu.npu_apply_adam(
                    beta1_power,
                    beta2_power,
                    lr,
                    beta1,
                    beta2,
                    epsilon,
                    grad,
                    False,  # use_locking
                    False,  # use_nesterov
                    out=(param.data, exp_avg, exp_avg_sq))

            # The kernel result must unpack into exactly three values; only the
            # first is kept, by rebinding the parameter's data.
            param.data, _, _ = updated


class FusedAdamBuilder(NPUOpBuilder):
    """Op builder that exposes NPUFusedAdam as the ``fused_adam`` op."""

    BUILD_VAR = "DS_BUILD_FUSED_ADAM"
    NAME = "fused_adam"

    def __init__(self):
        """Initialise the parent builder with this op's name."""
        super().__init__(name=self.NAME)

    def absolute_name(self):
        """Return the dotted module name for this op."""
        return f'deepspeed.ops.adam.{self.NAME}_op'

    def sources(self):
        """Return the source file list, which is empty for this builder."""
        return []

    def include_paths(self):
        """Return the include path list, which is empty for this builder."""
        return []

    def load(self, verbose=True):
        """Return the NPUFusedAdam class; ``verbose`` is not used."""
        return NPUFusedAdam