utils.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. from deepspeed.utils import log_dist
  3. # helper function to map between DS policies and DS containers
  4. def policy_to_ds_container(**kwargs):
  5. from .containers import HFGPT2LayerPolicy, DS_GPT2Container
  6. from .containers import HFBertLayerPolicy, DS_BERTContainer
  7. from .containers import BLOOMLayerPolicy, DS_BloomContainer
  8. from .containers import HFGPTJLayerPolicy, DS_GPTJContainer
  9. from .containers import HFGPTNEOLayerPolicy, DS_GPTNEOContainer
  10. from .containers import GPTNEOXLayerPolicy, DS_GPTNEOXContainer
  11. from .containers import HFOPTLayerPolicy, DS_OPTContainer
  12. from .containers import MegatronLayerPolicy, DS_MegatronGPTContainer
  13. from .containers import HFDistilBertLayerPolicy, DS_DistilBERTContainer
  14. policy_to_container = {
  15. HFGPT2LayerPolicy: DS_GPT2Container,
  16. HFBertLayerPolicy: DS_BERTContainer,
  17. BLOOMLayerPolicy: DS_BloomContainer,
  18. HFGPTJLayerPolicy: DS_GPTJContainer,
  19. HFGPTNEOLayerPolicy: DS_GPTNEOContainer,
  20. GPTNEOXLayerPolicy: DS_GPTNEOXContainer,
  21. HFOPTLayerPolicy: DS_OPTContainer,
  22. MegatronLayerPolicy: DS_MegatronGPTContainer,
  23. HFDistilBertLayerPolicy: DS_DistilBERTContainer,
  24. }
  25. container = None
  26. policy = kwargs['policy']
  27. assert policy is not None, "Policy cannot be None"
  28. policy_type = type(policy)
  29. if policy_type not in policy_to_container:
  30. log_dist(f"Policy type {policy_type} not supported", [0])
  31. else:
  32. container = policy_to_container[policy_type](**kwargs)
  33. return container