1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import pytest
- import torch
- import deepspeed.comm as dist
- from deepspeed import initialize
- from transformers import AutoModel
- from unit.common import DistributedTest
- from deepspeed.sequence.layer import _SeqAllToAll
- from unit.util import skip_on_arch
# Build the data-parallel and sequence-parallel groups from the DeepSpeed config
class TestUlyssesUtils(DistributedTest):
    """Checks the parallel groups that ``initialize`` attaches to the engine."""
    world_size = 4

    def test_mesh_device_creation(self) -> None:
        """Initialize with dp=2 / sp=2 and verify both groups exist with the requested sizes."""
        # NOTE(review): presumably skips on accelerators below arch 8 - confirm in unit.util
        skip_on_arch(min_arch=8)

        sp_size = 2
        dp_size = 2
        ds_config = {
            "train_batch_size": 8,
            "data_parallel_size": dp_size,
            "sequence_parallel_size": sp_size,
        }

        bert = AutoModel.from_pretrained('bert-base-uncased')
        engine, _, _, _ = initialize(model=bert, config_params=ds_config)

        # Both process groups must have been created by initialize().
        sp_group = engine.seq_parallel_group
        assert sp_group is not None
        dp_group = engine.data_parallel_group
        assert dp_group is not None

        # Each group has the configured size, and together they span all ranks.
        assert dist.get_world_size(group=sp_group) == sp_size
        assert dist.get_world_size(group=dp_group) == dp_size
        assert dist.get_world_size() == sp_size * dp_size
#Sweep b,s,h,d to test all2all consistency
@pytest.mark.parametrize("d0", [2, 4])  #batch or sequence dimension
@pytest.mark.parametrize("d1", [4, 8])  #batch or sequence dimension
@pytest.mark.parametrize("num_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [16, 32])
class TestUlyssesAll2All(DistributedTest):
    """Round-trip check for ``_SeqAllToAll`` over a sweep of 4D tensor shapes."""
    world_size = 4

    def test_alltoall_output_consistency(self, d0: int, d1: int, head_dim: int, num_heads: int) -> None:
        """Apply two all-to-alls with swapped scatter/gather indices and expect the input back.

        Args:
            d0: Size of the first tensor dimension (batch or sequence).
            d1: Size of the second tensor dimension (batch or sequence).
            head_dim: Size of the last tensor dimension.
            num_heads: Size of the third tensor dimension, used as the scatter index.
        """
        # NOTE(review): presumably skips on accelerators below arch 8 - confirm in unit.util
        skip_on_arch(min_arch=8)
        model = AutoModel.from_pretrained('bert-base-uncased')
        ds_engine, _, _, _ = initialize(model=model, config_params={"train_batch_size": 8}, mesh_param=(2, 2))

        #4D tensor : b,s,h,d or s,b,h,d
        input_tensor = torch.randn(d0, d1, num_heads, head_dim, device=ds_engine.device)
        scatter_idx = 2
        batch_dim_idx = 0
        outputs = []
        seq_dims = [0]  #seq first API
        #TODO: Add support for batch first (that seq_dims=[0,1]) after PR for bs>1 issue with batch first is fixed
        ## See discussion in : https://github.com/microsoft/DeepSpeed/issues/5808
        for seq_dim in seq_dims:
            gather_idx = seq_dim
            #first all2all: sequence parallel to head parallel
            s2h_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, input_tensor, scatter_idx, gather_idx,
                                            batch_dim_idx)

            # second all2all: head parallel to sequence parallel (scatter/gather indices swapped)
            h2s_tensor = _SeqAllToAll.apply(ds_engine.seq_parallel_group, s2h_tensor, gather_idx, scatter_idx,
                                            batch_dim_idx)
            print(
                f'[{dist.get_rank()}] s={seq_dim} input: {input_tensor.shape} s2h: {s2h_tensor.shape} h2s_tensor: {h2s_tensor.shape}'
            )
            outputs.append(h2s_tensor)

        # Check every round-tripped output against the input. The comparison must start at
        # the first entry: with a single seq_dim there is exactly one output, and skipping
        # index 0 would leave the test without any assertion.
        for seq_dim, h2s_tensor in zip(seq_dims, outputs):
            assert torch.allclose(input_tensor, h2s_tensor), f"Outputs differ for sequence dim {seq_dim}"
|