test_ulysses.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed.comm as dist
from deepspeed import initialize
from transformers import AutoModel
from unit.common import DistributedTest
from deepspeed.sequence.layer import _SeqAllToAll
from unit.util import skip_on_arch
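
# These tests exercise DeepSpeed-Ulysses sequence parallelism: the first verifies that
# initialize() builds the data-parallel and sequence-parallel process groups from the mesh
# configuration, and the second verifies that a scatter/gather round trip through _SeqAllToAll
# returns the original tensor. Both use 4 ranks and, via skip_on_arch(min_arch=8), skip on
# older GPU architectures (compute capability below 8).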


# Use a mesh device to create the data-parallel and sequence-parallel groups
class TestUlyssesUtils(DistributedTest):
    world_size = 4

    def test_mesh_device_creation(self) -> None:
        skip_on_arch(min_arch=8)
        model = AutoModel.from_pretrained('bert-base-uncased')
        sp_size = 2
        dp_size = 2
        ds_engine, _, _, _ = initialize(
            model=model,
            config_params={
                "train_batch_size": 8,
                "data_parallel_size": dp_size,
                "sequence_parallel_size": sp_size
            },
        )
        assert ds_engine.seq_parallel_group is not None
        assert ds_engine.data_parallel_group is not None
        assert dist.get_world_size(group=ds_engine.seq_parallel_group) == sp_size
        assert dist.get_world_size(group=ds_engine.data_parallel_group) == dp_size
        assert dist.get_world_size() == sp_size * dp_size
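

# With the 2x2 mesh above, each of the 4 ranks belongs to exactly one sequence-parallel group
# and one data-parallel group. A minimal sketch of an additional sanity check one could append
# to the test (illustration only, not part of the original test; it relies only on
# deepspeed.comm.get_rank(group=...)):
#
#     sp_rank = dist.get_rank(group=ds_engine.seq_parallel_group)
#     dp_rank = dist.get_rank(group=ds_engine.data_parallel_group)
#     assert 0 <= sp_rank < sp_size and 0 <= dp_rank < dp_size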


# Sweep d0, d1, num_heads, head_dim (i.e. b, s, h, d) to test all-to-all consistency
@pytest.mark.parametrize("d0", [2, 4])  # batch or sequence dimension
@pytest.mark.parametrize("d1", [4, 8])  # batch or sequence dimension
@pytest.mark.parametrize("num_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [16, 32])
class TestUlyssesAll2All(DistributedTest):
    world_size = 4

    def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
        skip_on_arch(min_arch=8)
        model = AutoModel.from_pretrained('bert-base-uncased')
        ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))
        # 4D tensor: (b, s, h, d) or (s, b, h, d)
        input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
        scatter_idx = 2
        batch_dim_idx = 0
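        # scatter_idx=2 splits the tensor along the head dimension, while gather_idx (set to the
        # sequence dimension below) is where the chunks received from the other sequence-parallel
        # ranks are concatenated. With a sequence-parallel group of size 2 (from mesh_param=(2, 2)),
        # the first all-to-all is expected to map (d0, d1, num_heads, head_dim) to
        # (d0 * 2, d1, num_heads // 2, head_dim) for gather_idx=0, and the second all-to-all to
        # reverse it; this is an expectation about _SeqAllToAll's layout, also visible in the
        # shapes printed below.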
        outputs = []
        seq_dims = [0]  # sequence-first API
        # TODO: also cover the batch-first layout (seq_dims = [0, 1]) once the bs > 1 issue with
        # batch-first inputs is fixed; see https://github.com/microsoft/DeepSpeed/issues/5808
        for seq_dim in seq_dims:
            gather_idx = seq_dim
            # first all-to-all: sequence-parallel layout to head-parallel layout
            s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
                                            batch_dim_idx)
            # second all-to-all: head-parallel layout back to sequence-parallel layout;
            # together the two calls should be a no-op
            h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
                                            batch_dim_idx)
            print(
                f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
            )
            outputs.append(h2s_tensor)

        # Check that each round-tripped output matches the input
        # (range starts at 0 so the single seq_dims=[0] case is actually verified)
        for i in range(len(outputs)):
            assert torch.allclose(input_tensor, outputs[i]), f"Outputs differ for sequence dim {seq_dims[i]}"
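

# To run this file (assuming the usual DeepSpeed unit-test layout; DistributedTest spawns the
# 4 worker processes itself, so no separate launcher is needed):
#
#     pytest tests/unit/sequence_parallelism/test_ulysses.py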