mixed_precision_linkage.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import types
  5. from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
  6. from deepspeed.utils import set_full_hp_param
  7. def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
  8. param_group_index, partition_start, partition_size, dp_group):
  9. local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group)
  10. for lp_param, lp_start in local_lp_param_and_offset:
  11. lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict,
  12. offload_gradient_dict, use_offload, param_group_index,
  13. partition_start, partition_size)
  14. def lazy_init_hp_params_optimizer_state(lp_param_list, flat_hp_partition, optimizer_state):
  15. for lp in lp_param_list:
  16. if lp._hp_mapping is not None:
  17. lp._hp_mapping.set_optim_state_fragment(flat_hp_partition, optimizer_state[flat_hp_partition])
  18. def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group):
  19. current_offset = 0
  20. param_and_offset_list = []
  21. partition_end = partition_start + partition_size
  22. index_in_param_group = 0
  23. for i, lp_param in enumerate(lp_param_list):
  24. lp_param._hp_mapping = None
  25. lp_param._dp_group = dp_group
  26. lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
  27. lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
  28. lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
  29. # lp_param overlaps with partition if both are true
  30. # 1) current_offset < partition_end,
  31. # 2) current_offset + lp_param.numel() >= partition_start
  32. lp_param_end = current_offset + lp_param.numel()
  33. if current_offset < partition_end and lp_param_end > partition_start:
  34. param_and_offset_list.append((lp_param, current_offset))
  35. lp_param._index_in_param_group = index_in_param_group
  36. # Indices for params in this partition/GPU
  37. index_in_param_group += 1
  38. current_offset += lp_param.numel()
  39. return param_and_offset_list