mixed_precision_linkage.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. """
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. """
  4. import types
  5. from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
  6. def link_hp_params(lp_param_list,
  7. flat_hp_partition,
  8. gradient_dict,
  9. offload_gradient_dict,
  10. use_offload,
  11. param_group_index,
  12. partition_start,
  13. partition_size,
  14. partition_optimizer_state,
  15. dp_group):
  16. local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list,
  17. partition_start,
  18. partition_size,
  19. dp_group)
  20. for lp_param, lp_start in local_lp_param_and_offset:
  21. lp_param._hp_mapping = get_hp_fragment_mapping(lp_param,
  22. lp_start,
  23. flat_hp_partition,
  24. gradient_dict,
  25. offload_gradient_dict,
  26. use_offload,
  27. param_group_index,
  28. partition_start,
  29. partition_size,
  30. partition_optimizer_state)
  31. def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
  32. current_offset = 0
  33. param_and_offset_list = []
  34. partition_end = partition_start + partition_size
  35. index_in_param_group = 0
  36. for i, lp_param in enumerate(lp_param_list):
  37. lp_param._hp_mapping = None
  38. lp_param._dp_group = dp_group
  39. lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
  40. lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
  41. # lp_param overlaps with partition if both are true
  42. # 1) current_offset < partition_end,
  43. # 2) current_offset + lp_param.numel() >= partition_start
  44. lp_param_end = current_offset + lp_param.numel()
  45. if current_offset < partition_end and lp_param_end > partition_start:
  46. param_and_offset_list.append((lp_param, current_offset))
  47. lp_param._index_in_param_group = index_in_param_group
  48. # Indices for params in this partition/GPU
  49. index_in_param_group += 1
  50. current_offset += lp_param.numel()
  51. return param_and_offset_list