replace_policy.py 989 B

123456789101112131415161718192021222324252627
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .containers import HFGPT2LayerPolicy
  5. from .containers import HFBertLayerPolicy
  6. from .containers import BLOOMLayerPolicy
  7. from .containers import HFGPTJLayerPolicy
  8. from .containers import HFGPTNEOLayerPolicy
  9. from .containers import GPTNEOXLayerPolicy
  10. from .containers import HFOPTLayerPolicy
  11. from .containers import MegatronLayerPolicy
  12. from .containers import HFDistilBertLayerPolicy
  13. from .containers import HFCLIPLayerPolicy
  14. from .containers import LLAMALayerPolicy
  15. from .containers import UNetPolicy
  16. from .containers import VAEPolicy
  17. # transformer-based policies
  18. replace_policies = [
  19. HFBertLayerPolicy, HFGPTNEOLayerPolicy, GPTNEOXLayerPolicy, HFGPTJLayerPolicy, MegatronLayerPolicy,
  20. HFGPT2LayerPolicy, BLOOMLayerPolicy, HFOPTLayerPolicy, HFCLIPLayerPolicy, HFDistilBertLayerPolicy, LLAMALayerPolicy
  21. ]
  22. # non-transformer-based policies
  23. generic_policies = [UNetPolicy, VAEPolicy]