
'''Copyright The Microsoft DeepSpeed Team'''

from typing import List, Tuple, Dict

import torch

from .layer import MoE


def has_moe_layers(m):
    has_moe = False
    num_experts = 0
    for _, module in m.named_modules():
        if isinstance(module, MoE):
            has_moe = True
            num_experts = module.num_experts
            break
    return has_moe, num_experts


def is_moe_param(param: torch.Tensor) -> bool:
    if hasattr(param, "allreduce") and not param.allreduce:
        return True
    return False


def split_params_into_shared_and_expert_params(
        params: List[torch.nn.Parameter]
) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
    shared_params, expert_params = [], []
    for p in params:
        if is_moe_param(p):
            expert_params.append(p)
        else:
            shared_params.append(p)
    return shared_params, expert_params
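

# Illustrative usage sketch (not part of the original file): the split lets an
# optimizer treat shared and expert parameters differently, e.g. with separate
# weight decay. `model` is a hypothetical module containing MoE layers.
#
#   shared_params, expert_params = split_params_into_shared_and_expert_params(
#       list(model.parameters()))
#   optimizer = torch.optim.AdamW([
#       {'params': shared_params, 'weight_decay': 0.01},
#       {'params': expert_params, 'weight_decay': 0.0},
#   ])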


def split_params_grads_into_shared_and_expert_params(
        group: List[torch.nn.Parameter]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split the gradients of a parameter group into gradients of non-expert
    (shared) params and gradients of expert params. This is useful when
    computing grad norms for clipping and overflow detection.

    Args:
        group (List[torch.nn.Parameter]):
            The group of parameters to split

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]:
            list of gradients for non-MoE params, list of gradients for MoE params
    """
    expert_grads = []
    shared_grads = []
    for p in group:
        if p.grad is not None:
            if is_moe_param(p):
                expert_grads.append(p.grad.to(p.dtype))
            else:
                shared_grads.append(p.grad.to(p.dtype))
    return shared_grads, expert_grads
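

# Illustrative sketch (not part of the original file): computing separate grad
# norms for shared and expert gradients, the use case mentioned in the
# docstring above. The helper name `_example_split_grad_norms` is hypothetical.
def _example_split_grad_norms(group: List[torch.nn.Parameter]) -> Tuple[float, float]:
    shared_grads, expert_grads = split_params_grads_into_shared_and_expert_params(group)
    # Global L2 norm over all shared (non-expert) gradients.
    shared_norm = torch.norm(torch.stack([g.norm(2) for g in shared_grads])).item() if shared_grads else 0.0
    # Global L2 norm over all expert gradients; in practice this would still be
    # reduced across the expert-parallel process group before use.
    expert_norm = torch.norm(torch.stack([g.norm(2) for g in expert_grads])).item() if expert_grads else 0.0
    return shared_norm, expert_norm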


def split_params_into_different_moe_groups_for_optimizer(
        param_groups: Tuple[Dict],
        max_group_size=178956971) -> Tuple[Dict]:
    """Split parameters into different MoE groups for optimizer

    Args:
        param_groups (Tuple[Dict]):
            The list of parameter groups to split

    Returns:
        Tuple[Dict]:
            list of MoE/non-MoE groups for optimizer
    """
    if isinstance(param_groups, tuple):
        param_groups = list(param_groups)  # Tuple cannot be modified
    elif isinstance(param_groups, dict):
        param_groups = [param_groups]
    elif not isinstance(param_groups, list):
        raise ValueError(f"Unknown param group type of {type(param_groups)}")

    # gather all data parallel group names
    data_parallel_group_names = set()
    for param_group in param_groups:
        for param in param_group["params"]:
            if is_moe_param(param):
                data_parallel_group_names.add(param.group_name)
    data_parallel_group_names = list(data_parallel_group_names)

    group_moe = {}
    # Create the param MoE groups, leave param assign to next step
    for param_group in param_groups:
        group_moe[param_group['name']] = {}
        for key in data_parallel_group_names:
            group_moe[param_group['name']][key] = {}
            group_moe[param_group['name']][key]['name'] = key
            group_moe[param_group['name']][key]['moe'] = True
            for ori_key in param_group.keys():
                if ori_key != 'name':
                    if ori_key == 'params':
                        group_moe[param_group['name']][key][ori_key] = []
                    else:
                        group_moe[param_group['name']][key][ori_key] = param_group[ori_key]

    # Assign param
    for param_group in param_groups:
        new_params = []
        for param in param_group['params']:
            if is_moe_param(param):
                group_moe[param_group['name']][param.group_name]['params'].append(param)
                # param_group['params'].remove(param)
            else:
                new_params.append(param)
        param_group['params'] = new_params

    # Flatten the moe groups
    if max_group_size is not None:
        for k, v in group_moe.items():
            for k1, v1 in v.items():
                cur_group = []
                all_groups = []
                size_of_cur_group = 0
                for param in v1['params']:
                    if size_of_cur_group + param.numel() <= max_group_size:
                        cur_group.append(param)
                        size_of_cur_group += param.numel()
                    else:
                        all_groups.append(cur_group)
                        cur_group = [param]
                        size_of_cur_group = param.numel()
                if cur_group:
                    all_groups.append(cur_group)
                for group in all_groups:
                    new_dict = {}
                    for key, val in v1.items():
                        if key != 'params':
                            new_dict[key] = val
                    new_dict['params'] = group
                    param_groups.append(new_dict)
    else:
        for k, v in group_moe.items():
            for k1, v1 in v.items():
                param_groups.append(v1)

    return tuple(param_groups)
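

# Illustrative usage sketch (not part of the original file): the returned
# groups can be handed to a torch optimizer; expert groups carry `moe=True`
# and the data-parallel `group_name`, so gradient reduction can be scoped to
# the right process group downstream. `model` is a hypothetical MoE model.
#
#   param_groups = split_params_into_different_moe_groups_for_optimizer(
#       {'name': 'dense', 'params': list(model.parameters()), 'lr': 1e-3})
#   optimizer = torch.optim.AdamW(param_groups)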