utils.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from deepspeed.utils import log_dist
  5. # helper function to map between DS policies and DS containers
  6. def policy_to_ds_container(**kwargs):
  7. from .containers import HFGPT2LayerPolicy, DS_GPT2Container
  8. from .containers import HFBertLayerPolicy, DS_BERTContainer
  9. from .containers import BLOOMLayerPolicy, DS_BloomContainer
  10. from .containers import HFGPTJLayerPolicy, DS_GPTJContainer
  11. from .containers import HFGPTNEOLayerPolicy, DS_GPTNEOContainer
  12. from .containers import GPTNEOXLayerPolicy, DS_GPTNEOXContainer
  13. from .containers import HFOPTLayerPolicy, DS_OPTContainer
  14. from .containers import MegatronLayerPolicy, DS_MegatronGPTContainer
  15. from .containers import HFDistilBertLayerPolicy, DS_DistilBERTContainer
  16. from .containers import LLAMALayerPolicy, DS_LLAMAContainer
  17. from .containers import LLAMA2LayerPolicy, DS_LLAMA2Container
  18. from .containers import InternLMLayerPolicy, DS_InternLMContainer
  19. policy_to_container = {
  20. HFGPT2LayerPolicy: DS_GPT2Container,
  21. HFBertLayerPolicy: DS_BERTContainer,
  22. BLOOMLayerPolicy: DS_BloomContainer,
  23. HFGPTJLayerPolicy: DS_GPTJContainer,
  24. HFGPTNEOLayerPolicy: DS_GPTNEOContainer,
  25. GPTNEOXLayerPolicy: DS_GPTNEOXContainer,
  26. HFOPTLayerPolicy: DS_OPTContainer,
  27. MegatronLayerPolicy: DS_MegatronGPTContainer,
  28. HFDistilBertLayerPolicy: DS_DistilBERTContainer,
  29. LLAMALayerPolicy: DS_LLAMAContainer,
  30. LLAMA2LayerPolicy: DS_LLAMA2Container,
  31. InternLMLayerPolicy: DS_InternLMContainer
  32. }
  33. container = None
  34. policy = kwargs['policy']
  35. assert policy is not None, "Policy cannot be None"
  36. policy_type = type(policy)
  37. if policy_type not in policy_to_container:
  38. log_dist(f"Policy type {policy_type} not supported", [0])
  39. else:
  40. container = policy_to_container[policy_type](**kwargs)
  41. return container