bwc.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team


def bwc_tensor_model_parallel_rank(mpu=None):
  5. """Backwards-compatible way of querying the tensor model parallel rank from
  6. an ``mpu`` object.
  7. *Tensor* model parallelism means that tensors are physically split across
  8. processes. This contrasts with *pipeline* model parallelism, in which the
  9. layers are partitioned but tensors left intact.
  10. The API for tensor model parallelism has changed across versions and this
  11. helper provides a best-effort implementation across versions of ``mpu``
  12. objects. The preferred mechanism is
  13. ``mpu.get_tensor_model_parallel_rank()``.
  14. This should "just work" with both Megatron-LM and DeepSpeed's pipeline
  15. parallelism.
  16. Args:
  17. mpu (model parallel unit, optional): The tensor model parallel rank.
  18. If ``mpu=None``, returns 0. Defaults to ``None``.
  19. Returns:
  20. int: the rank
  21. """
    if mpu is None:
        # No model parallelism, so the rank is trivially 0 :)
        return 0

    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
        return mpu.get_tensor_model_parallel_rank()
    elif hasattr(mpu, 'get_slice_parallel_rank'):
        # Some DeepSpeed + pipeline parallelism versions
        return mpu.get_slice_parallel_rank()
    else:
        # Deprecated Megatron and DeepSpeed convention
        return mpu.get_model_parallel_rank()
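

# Illustrative dispatch order (a sketch of the pattern used throughout this
# module): each bwc_tensor_* helper probes the mpu for accessors in the same order:
#   mpu has get_tensor_model_parallel_*  -> the current convention is used
#   mpu has only get_slice_parallel_*    -> the DeepSpeed slice accessor is used
#   neither                              -> the deprecated get_model_parallel_* accessor is used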


def bwc_tensor_model_parallel_world_size(mpu=None):
    """Backwards-compatible way of querying the tensor model parallel world size.

    Similar to ``bwc_tensor_model_parallel_rank``.
    """
    if mpu is None:
        return 1

    if hasattr(mpu, 'get_tensor_model_parallel_world_size'):
        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
        return mpu.get_tensor_model_parallel_world_size()
    elif hasattr(mpu, 'get_slice_parallel_world_size'):
        # Some DeepSpeed + pipeline parallelism versions
        return mpu.get_slice_parallel_world_size()
    else:
        # Deprecated Megatron and DeepSpeed convention
        return mpu.get_model_parallel_world_size()


def bwc_tensor_model_parallel_group(mpu=None):
    """Backwards-compatible way of querying the tensor model parallel group.

    Similar to ``bwc_tensor_model_parallel_rank``.
    """
    if mpu is None:
        return None

    if hasattr(mpu, 'get_tensor_model_parallel_group'):
        # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
        return mpu.get_tensor_model_parallel_group()
    elif hasattr(mpu, 'get_slice_parallel_group'):
        # Some DeepSpeed + pipeline parallelism versions
        return mpu.get_slice_parallel_group()
    else:
        # Deprecated Megatron and DeepSpeed convention
        return mpu.get_model_parallel_group()
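
# Usage note (illustrative): the group returned above is a torch.distributed
# process group and can be passed as the ``group`` argument of collectives, e.g.
#   dist.all_reduce(t, group=bwc_tensor_model_parallel_group(mpu))
# With ``mpu=None`` the returned None selects the default (world) group.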


def bwc_pipeline_parallel_world_size(mpu=None):
    """Backwards-compatible way of querying the pipeline parallel world size."""
    world_size = 1
    if mpu is not None:
        if hasattr(mpu, 'get_pipeline_model_parallel_world_size'):
            # New Megatron and DeepSpeed convention (post pipeline-parallelism release)
            world_size = mpu.get_pipeline_model_parallel_world_size()
        elif hasattr(mpu, 'get_pipe_parallel_world_size'):
            # DeepSpeed Topology
            world_size = mpu.get_pipe_parallel_world_size()
    return world_size


def bwc_pipeline_parallel_group(mpu=None):
    """Backwards-compatible way of querying the pipeline parallel group."""
    if mpu is None:
        return None
    if hasattr(mpu, 'get_pipeline_model_parallel_group'):
        # Megatron
        return mpu.get_pipeline_model_parallel_group()
    elif hasattr(mpu, 'get_pipe_parallel_group'):
        # DeepSpeed Topology
        return mpu.get_pipe_parallel_group()
    assert False, 'mpu does not support pipeline parallel group'
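

if __name__ == "__main__":
    # Minimal sketch of how the helpers dispatch. _NewStyleMPU and _LegacyMPU are
    # hypothetical stand-ins for Megatron/DeepSpeed mpu objects; a real mpu would
    # return ranks, sizes, and torch.distributed process groups for the current
    # process rather than fixed values.

    class _NewStyleMPU:
        # Post pipeline-parallelism naming convention.
        def get_tensor_model_parallel_rank(self):
            return 1

        def get_tensor_model_parallel_world_size(self):
            return 4

        def get_pipeline_model_parallel_world_size(self):
            return 2

    class _LegacyMPU:
        # Deprecated naming convention; only the old accessors exist.
        def get_model_parallel_rank(self):
            return 3

        def get_model_parallel_world_size(self):
            return 8

    assert bwc_tensor_model_parallel_rank() == 0                    # no mpu -> rank 0
    assert bwc_tensor_model_parallel_rank(_NewStyleMPU()) == 1      # current accessor is used
    assert bwc_tensor_model_parallel_rank(_LegacyMPU()) == 3        # falls back to legacy accessor
    assert bwc_tensor_model_parallel_world_size(_NewStyleMPU()) == 4
    assert bwc_tensor_model_parallel_world_size(_LegacyMPU()) == 8
    assert bwc_pipeline_parallel_world_size() == 1                  # no mpu -> world size 1
    assert bwc_pipeline_parallel_world_size(_NewStyleMPU()) == 2
    print("bwc helpers dispatched as expected")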