z3_leaf_module.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. from typing import List, Type
  6. def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type], flag: bool) -> None:
  7. assert all(isinstance(module_class, type) for module_class in leaf_module_classes), \
  8. f'leaf_module_classes must be a list of types, got {leaf_module_classes}'
  9. def _set_z3_leaf_flag(model: torch.nn.Module):
  10. if model.__class__ in leaf_module_classes:
  11. model._z3_leaf = flag
  12. model.apply(_set_z3_leaf_flag)
  13. def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
  14. """Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
  15. This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
  16. Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
  17. Args:
  18. model (torch.nn.Module): The model to which the leaf module flag will be applied.
  19. leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
  20. """
  21. _do_set_z3_leaf_modules(model, leaf_module_classes, True)
  22. def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> None:
  23. """Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
  24. See `set_z3_leaf_modules` for more details.
  25. Args:
  26. model (torch.nn.Module): The model to which the leaf module flag will be applied.
  27. leaf_module_classes (List[Type]): A list of module classes that should be flagged as 'leaf' modules.
  28. """
  29. _do_set_z3_leaf_modules(model, leaf_module_classes, False)
  30. def z3_leaf_module(model: torch.nn.Module) -> bool:
  31. """Returns whether a module in `model` has been flagged as a 'leaf' module.
  32. See `set_z3_leaf_modules` for more details.
  33. Args:
  34. model (torch.nn.Module): The model to which the leaf module flag will be applied.
  35. """
  36. return hasattr(model, '_z3_leaf') and model._z3_leaf