replace_policy.py 1.1 KB

123456789101112131415161718192021222324252627282930
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .containers import HFGPT2LayerPolicy
  5. from .containers import HFBertLayerPolicy
  6. from .containers import BLOOMLayerPolicy
  7. from .containers import HFGPTJLayerPolicy
  8. from .containers import HFGPTNEOLayerPolicy
  9. from .containers import GPTNEOXLayerPolicy
  10. from .containers import HFOPTLayerPolicy
  11. from .containers import MegatronLayerPolicy
  12. from .containers import HFDistilBertLayerPolicy
  13. from .containers import HFCLIPLayerPolicy
  14. from .containers import LLAMALayerPolicy
  15. from .containers import UNetPolicy
  16. from .containers import VAEPolicy
  17. from .containers import LLAMA2LayerPolicy
  18. from .containers import InternLMLayerPolicy
  # Policies for transformer-style layers, kept in a single module-level list.
# NOTE(review): entries are kept in their original order; if a consumer
# picks the first policy that matches a module, reordering would change
# behavior -- confirm before rearranging.
replace_policies = [
    HFBertLayerPolicy,
    HFGPTNEOLayerPolicy,
    GPTNEOXLayerPolicy,
    HFGPTJLayerPolicy,
    MegatronLayerPolicy,
    HFGPT2LayerPolicy,
    BLOOMLayerPolicy,
    HFOPTLayerPolicy,
    HFCLIPLayerPolicy,
    HFDistilBertLayerPolicy,
    LLAMALayerPolicy,
    LLAMA2LayerPolicy,
    InternLMLayerPolicy,
]
  # Policies for components that are not transformer layers, kept
# separate from the transformer-based list.
generic_policies = [
    UNetPolicy,
    VAEPolicy,
]