utils.py

from typing import List, Tuple

import torch


def is_moe_param(param: torch.Tensor) -> bool:
    """Return True if the parameter is an expert (MoE) parameter.

    Expert parameters are tagged with an ``allreduce`` attribute set to False.
    """
    if hasattr(param, "allreduce") and not param.allreduce:
        return True
    return False


def split_params_into_shared_and_expert_params(
    params: List[torch.nn.Parameter],
) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
    """Split parameters into non-expert (shared) params and expert (MoE) params."""
    shared_params, expert_params = [], []
    for p in params:
        if is_moe_param(p):
            expert_params.append(p)
        else:
            shared_params.append(p)
    return shared_params, expert_params


def split_params_grads_into_shared_and_expert_params(
    group: List[torch.nn.Parameter],
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split the gradients of parameters into grads of non-expert params
    and grads of expert params. This is useful while computing
    grad-norms for clipping and overflow detection.

    Args:
        group (List[torch.nn.Parameter]):
            The group of parameters to split

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]:
            list of gradients for non-MoE params, list of gradients of MoE params
    """
    expert_grads = []
    shared_grads = []
    for p in group:
        if p.grad is not None:
            if is_moe_param(p):
                expert_grads.append(p.grad.to(p.dtype))
            else:
                shared_grads.append(p.grad.to(p.dtype))
    return shared_grads, expert_grads
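
A minimal usage sketch (not part of utils.py), assuming only the convention that is_moe_param checks for: expert parameters carry an allreduce attribute set to False. The tensors below are illustrative placeholders.

# Illustrative usage of the helpers above.
import torch

shared_weight = torch.nn.Parameter(torch.randn(4, 4))
expert_weight = torch.nn.Parameter(torch.randn(4, 4))
expert_weight.allreduce = False  # convention checked by is_moe_param

params = [shared_weight, expert_weight]
shared, experts = split_params_into_shared_and_expert_params(params)
assert shared == [shared_weight] and experts == [expert_weight]

# After a backward pass, gradients can be split the same way,
# e.g. to compute separate grad norms for shared and expert params.
(shared_weight.sum() + expert_weight.sum()).backward()
shared_grads, expert_grads = split_params_grads_into_shared_and_expert_params(params)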