mixed_precision_linkage.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import types
  5. from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
  6. def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
  7. param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group):
  8. local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group)
  9. for lp_param, lp_start in local_lp_param_and_offset:
  10. lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict,
  11. offload_gradient_dict, use_offload, param_group_index,
  12. partition_start, partition_size, partition_optimizer_state)
  13. def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
  14. current_offset = 0
  15. param_and_offset_list = []
  16. partition_end = partition_start + partition_size
  17. index_in_param_group = 0
  18. for i, lp_param in enumerate(lp_param_list):
  19. lp_param._hp_mapping = None
  20. lp_param._dp_group = dp_group
  21. lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
  22. lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
  23. # lp_param overlaps with partition if both are true
  24. # 1) current_offset < partition_end,
  25. # 2) current_offset + lp_param.numel() >= partition_start
  26. lp_param_end = current_offset + lp_param.numel()
  27. if current_offset < partition_end and lp_param_end > partition_start:
  28. param_and_offset_list.append((lp_param, current_offset))
  29. lp_param._index_in_param_group = index_in_param_group
  30. # Indices for params in this partition/GPU
  31. index_in_param_group += 1
  32. current_offset += lp_param.numel()
  33. return param_and_offset_list