utils.py

from typing import List, Tuple

import torch


def is_moe_param(param: torch.Tensor) -> bool:
    # Expert (MoE) parameters are tagged with ``allreduce = False`` so that
    # they are excluded from the data-parallel all-reduce.
    return hasattr(param, "allreduce") and not param.allreduce


def split_params_into_shared_and_expert_params(
        params: List[torch.nn.Parameter]
) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
    """Split parameters into non-expert (shared) params and expert (MoE) params."""
    shared_params, expert_params = [], []
    for p in params:
        if is_moe_param(p):
            expert_params.append(p)
        else:
            shared_params.append(p)
    return shared_params, expert_params


def split_params_grads_into_shared_and_expert_params(
        group: List[torch.nn.Parameter]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Split the gradients of a parameter group into gradients of non-expert
    (shared) params and gradients of expert (MoE) params. This is useful when
    computing grad norms for clipping and overflow detection.

    Args:
        group (List[torch.nn.Parameter]):
            The group of parameters to split.
    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]:
            List of gradients of non-MoE params, list of gradients of MoE params.
    """
    expert_grads = []
    shared_grads = []
    for p in group:
        if p.grad is not None:
            if is_moe_param(p):
                expert_grads.append(p.grad.to(p.dtype))
            else:
                shared_grads.append(p.grad.to(p.dtype))
    return shared_grads, expert_grads
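

# --- Usage sketch (illustrative, not part of the original utils.py) ---
# A minimal example of how these helpers might be combined to compute separate
# gradient norms for shared and expert parameters. The two-layer model and the
# manual ``allreduce = False`` tagging below are assumptions for illustration;
# in practice the MoE layers tag their expert parameters themselves.
if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))

    # Pretend the second layer holds expert parameters by tagging them the way
    # is_moe_param() expects.
    for p in model[1].parameters():
        p.allreduce = False

    model(torch.randn(4, 8)).sum().backward()

    shared_params, expert_params = split_params_into_shared_and_expert_params(
        list(model.parameters()))
    shared_grads, expert_grads = split_params_grads_into_shared_and_expert_params(
        list(model.parameters()))

    shared_norm = torch.norm(torch.stack([g.norm() for g in shared_grads])).item()
    expert_norm = torch.norm(torch.stack([g.norm() for g in expert_grads])).item()
    print(f"{len(shared_params)} shared params, grad norm {shared_norm:.4f}")
    print(f"{len(expert_params)} expert params, grad norm {expert_norm:.4f}")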