# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import List, Tuple, Dict

import torch

from .layer import MoE


def has_moe_layers(m):
    """Return (has_moe, num_experts) for model `m`, where `num_experts` is taken
    from the first MoE layer found (0 if the model has no MoE layers)."""
    has_moe = False
    num_experts = 0
    for _, module in m.named_modules():
        if isinstance(module, MoE):
            has_moe = True
            num_experts = module.num_experts
            break
    return has_moe, num_experts
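
# Usage sketch (illustrative, not part of the module): `model` below is a
# hypothetical torch.nn.Module that may contain MoE layers (the MoE class
# imported above).
#
#   found_moe, num_experts = has_moe_layers(model)
#   if found_moe:
#       print(f"model has MoE layers with {num_experts} experts")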


def is_moe_param(param: torch.Tensor) -> bool:
    """Return True if `param` is an expert (MoE) parameter, i.e. it carries the
    `allreduce=False` attribute set on expert weights."""
    if hasattr(param, "allreduce") and not param.allreduce:
        return True
    return False
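
# Usage sketch (illustrative, not part of the module): expert parameters are
# recognized purely by the `allreduce` attribute, which DeepSpeed's MoE code
# sets to False on expert weights (together with a `group_name`, used below).
#
#   p = torch.nn.Parameter(torch.zeros(4))
#   assert not is_moe_param(p)      # plain parameter -> shared
#   p.allreduce = False
#   p.group_name = "ep_group"       # hypothetical expert-parallel group name
#   assert is_moe_param(p)          # now treated as an expert parameter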


def split_params_into_shared_and_expert_params(
        params: List[torch.nn.Parameter]) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
    """Split parameters into non-expert (shared) params and expert (MoE) params."""
    shared_params, expert_params = [], []
    for p in params:
        if is_moe_param(p):
            expert_params.append(p)
        else:
            shared_params.append(p)
    return shared_params, expert_params
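
# Usage sketch (illustrative, not part of the module): split a model's
# parameters so shared and expert weights can be handled differently, e.g. in
# separate optimizer groups. `model` and the weight-decay values are assumptions.
#
#   shared, experts = split_params_into_shared_and_expert_params(list(model.parameters()))
#   optimizer = torch.optim.AdamW([
#       {'params': shared, 'weight_decay': 0.01},
#       {'params': experts, 'weight_decay': 0.0},
#   ])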


def split_params_grads_into_shared_and_expert_params(
        group: List[torch.nn.Parameter]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split the gradients of parameters into gradients of non-expert params
    and gradients of expert params. This is useful when computing
    gradient norms for clipping and overflow detection.

    Args:
        group (List[torch.nn.Parameter]):
            The group of parameters to split

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]:
            list of gradients for non-MoE params, list of gradients of MoE params
    """
    expert_grads = []
    shared_grads = []
    for p in group:
        if p.grad is not None:
            if is_moe_param(p):
                expert_grads.append(p.grad.to(p.dtype))
            else:
                shared_grads.append(p.grad.to(p.dtype))
    return shared_grads, expert_grads
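
# Usage sketch (illustrative, not part of the module): after backward(), compute
# separate gradient norms for shared vs. expert params, e.g. as inputs to
# clipping or overflow checks. `params` is a hypothetical parameter list.
#
#   shared_grads, expert_grads = split_params_grads_into_shared_and_expert_params(params)
#   shared_norm = torch.norm(torch.stack([g.norm() for g in shared_grads]))
#   expert_norm = torch.norm(torch.stack([g.norm() for g in expert_grads]))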


def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
                                                         max_group_size=178956971) -> Tuple[Dict]:
    """Split parameters into different MoE groups for the optimizer.

    Args:
        param_groups (Tuple[Dict]):
            The list of parameter groups to split

    Returns:
        Tuple[Dict]:
            list of MoE/non-MoE groups for the optimizer
    """
    if isinstance(param_groups, tuple):
        param_groups = list(param_groups)  # Tuple cannot be modified
    elif isinstance(param_groups, dict):
        param_groups = [param_groups]
    elif not isinstance(param_groups, list):
        raise ValueError(f"Unknown param group type of {type(param_groups)}")

    # Gather all data parallel group names
    data_parallel_group_names = set()
    for param_group in param_groups:
        for param in param_group["params"]:
            if is_moe_param(param):
                data_parallel_group_names.add(param.group_name)
    data_parallel_group_names = list(data_parallel_group_names)

    # Create the param MoE groups; params are assigned in the next step
    group_moe = {}
    for param_group in param_groups:
        group_moe[param_group['name']] = {}
        for key in data_parallel_group_names:
            group_moe[param_group['name']][key] = {}
            group_moe[param_group['name']][key]['name'] = key
            group_moe[param_group['name']][key]['moe'] = True
            for ori_key in param_group.keys():
                if ori_key != 'name':
                    if ori_key == 'params':
                        group_moe[param_group['name']][key][ori_key] = []
                    else:
                        group_moe[param_group['name']][key][ori_key] = param_group[ori_key]

    # Assign params: expert params move to their MoE group, shared params stay
    for param_group in param_groups:
        new_params = []
        for param in param_group['params']:
            if is_moe_param(param):
                group_moe[param_group['name']][param.group_name]['params'].append(param)
                # param_group['params'].remove(param)
            else:
                new_params.append(param)
        param_group['params'] = new_params

    # Flatten the MoE groups, splitting each into chunks of at most max_group_size elements
    if max_group_size is not None:
        for k, v in group_moe.items():
            for k1, v1 in v.items():
                cur_group = []
                all_groups = []
                size_of_cur_group = 0
                for param in v1['params']:
                    if size_of_cur_group + param.numel() <= max_group_size:
                        cur_group.append(param)
                        size_of_cur_group += param.numel()
                    else:
                        all_groups.append(cur_group)
                        cur_group = [param]
                        size_of_cur_group = param.numel()
                if cur_group:
                    all_groups.append(cur_group)
                for group in all_groups:
                    new_dict = {}
                    for key, val in v1.items():
                        if key != 'params':
                            new_dict[key] = val
                    new_dict['params'] = group
                    param_groups.append(new_dict)
    else:
        for k, v in group_moe.items():
            for k1, v1 in v.items():
                param_groups.append(v1)

    return tuple(param_groups)
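
# Usage sketch (illustrative, not part of the module): expand a single named
# optimizer group into shared + per-expert-data-parallel-group entries before
# building the optimizer. The model, group name, and lr are assumptions; note
# that each input group must carry a 'name' key.
#
#   groups = [{'name': 'base', 'params': list(model.parameters()), 'lr': 1e-4}]
#   groups = split_params_into_different_moe_groups_for_optimizer(groups)
#   optimizer = torch.optim.AdamW(list(groups))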