replace_policy.py 918 B

1234567891011121314151617181920212223242526272829303132
  1. '''
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. '''
  4. from .containers import HFGPT2LayerPolicy
  5. from .containers import HFBertLayerPolicy
  6. from .containers import BLOOMLayerPolicy
  7. from .containers import HFGPTJLayerPolicy
  8. from .containers import HFGPTNEOLayerPolicy
  9. from .containers import GPTNEOXLayerPolicy
  10. from .containers import HFOPTLayerPolicy
  11. from .containers import MegatronLayerPolicy
  12. from .containers import HFDistilBertLayerPolicy
  13. from .containers import HFCLIPLayerPolicy
  14. from .containers import UNetPolicy
  15. from .containers import VAEPolicy
  # Injection policies for transformer-based layers.
# Kept as a plain (mutable) list; entries appear in exactly this order.
replace_policies = [
    HFBertLayerPolicy,
    HFGPTNEOLayerPolicy,
    GPTNEOXLayerPolicy,
    HFGPTJLayerPolicy,
    MegatronLayerPolicy,
    HFGPT2LayerPolicy,
    BLOOMLayerPolicy,
    HFOPTLayerPolicy,
    HFCLIPLayerPolicy,
    HFDistilBertLayerPolicy,
]
  # Policies for modules that are not transformer layers
# (the UNet and VAE policy classes imported above).
generic_policies = [
    UNetPolicy,
    VAEPolicy,
]