# mixed_precision_linkage.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
  4. import types
  5. from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
  6. from deepspeed.utils import set_full_hp_param
  7. def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
  8. param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group):
  9. local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group)
  10. for lp_param, lp_start in local_lp_param_and_offset:
  11. lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict,
  12. offload_gradient_dict, use_offload, param_group_index,
  13. partition_start, partition_size, partition_optimizer_state)
  14. def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
  15. current_offset = 0
  16. param_and_offset_list = []
  17. partition_end = partition_start + partition_size
  18. index_in_param_group = 0
  19. for i, lp_param in enumerate(lp_param_list):
  20. lp_param._hp_mapping = None
  21. lp_param._dp_group = dp_group
  22. lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
  23. lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
  24. lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
  25. # lp_param overlaps with partition if both are true
  26. # 1) current_offset < partition_end,
  27. # 2) current_offset + lp_param.numel() >= partition_start
  28. lp_param_end = current_offset + lp_param.numel()
  29. if current_offset < partition_end and lp_param_end > partition_start:
  30. param_and_offset_list.append((lp_param, current_offset))
  31. lp_param._index_in_param_group = index_in_param_group
  32. # Indices for params in this partition/GPU
  33. index_in_param_group += 1
  34. current_offset += lp_param.numel()
  35. return param_and_offset_list