# auto_tp_model_utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist
import torch
from typing import Optional


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409 - the Alibi tensor here is not causal as in the original paper;
    it relies on the translation invariance of softmax for a quick implementation: with `l` a tensor and `a` a fixed
    value, `softmax(l + a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21: this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`, *required*):
            number of heads
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor

    Returns:
        Tensor shaped (batch_size * num_heads, 1, max_seq_len).
    """
    import math
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2**math.floor(math.log2(num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        device=attention_mask.device,
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)
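
    # Alibi head slopes: for `closest_power_of_2` heads the slopes above form the geometric
    # sequence base**1, base**2, ..., i.e. 2**(-8*i/n) for n heads as in the Alibi paper.
    # When num_heads is not a power of two, the remaining heads below reuse every other slope
    # from the sequence of the next power of two, mirroring the linked fairseq implementation.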
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                                  device=attention_mask.device,
                                  dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that is applied to the query-key product of attention
    # => therefore alibi has to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcast correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
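
    # Under tensor parallelism, attention heads are split evenly across ranks, so each rank
    # keeps only the contiguous block of heads it owns and returns a correspondingly smaller bias.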
    if dist.is_initialized():
        num_heads_per_rank = int(num_heads / dist.get_world_size())
        offset = dist.get_rank() * num_heads_per_rank
        alibi = alibi.view(batch_size, num_heads, 1, seq_length)
        alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
        return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
    else:
        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
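

# The two MPT helpers below take `self` and call `self._attn_bias_orig` /
# `self.build_mpt_alibi_tensor_orig`, so they are presumably bound onto the model instance by
# DeepSpeed's auto tensor-parallel injection code, with the model's original methods saved under
# the `*_orig` names; that wiring is assumed here rather than shown in this file.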
def build_mpt_atten_bias_tensor(self,
                                device,
                                dtype,
                                attention_mask: Optional[torch.ByteTensor] = None,
                                prefix_mask: Optional[torch.ByteTensor] = None,
                                sequence_id: Optional[torch.LongTensor] = None):
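    """
    Build the MPT attention bias via the model's original `_attn_bias` (saved as
    `self._attn_bias_orig`) and, when tensor parallelism is active, slice the bias down to the
    heads owned by this rank.
    """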
    (attn_bias, attention_mask) = self._attn_bias_orig(device,
                                                       dtype,
                                                       attention_mask=attention_mask,
                                                       prefix_mask=prefix_mask,
                                                       sequence_id=sequence_id)
    if dist.is_initialized():
        num_heads_per_rank = int(self.config.n_heads / dist.get_world_size())
        offset = dist.get_rank() * num_heads_per_rank
        attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
    return attn_bias, attention_mask


def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None) -> torch.Tensor:
    r"""
    Link to paper: https://arxiv.org/abs/2108.12409 - the Alibi tensor here is not causal as in the original paper;
    it relies on the translation invariance of softmax for a quick implementation. This version is copied from the
    Alibi implementation in the MPT source code, which leads to slightly different results than the Bloom alibi:
    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
    """
    alibi = self.build_mpt_alibi_tensor_orig(num_heads, sequence_length, alibi_bias_max, device)
    if dist.is_initialized():
        num_heads_per_rank = int(num_heads / dist.get_world_size())
        offset = dist.get_rank() * num_heads_per_rank
        alibi = alibi[offset:num_heads_per_rank + offset, :, :]
    return alibi
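

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): build the Bloom
# Alibi bias for a toy batch in a single, non-distributed process, in which
# dist.is_initialized() is expected to return False so no head slicing happens.
# The toy shapes and padding pattern here are arbitrary assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len, n_heads = 2, 8, 16
    # 1 = real token, 0 = padding; pad the first two positions of the second row.
    mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
    mask[1, :2] = 0
    alibi = build_bloom_alibi_tensor(mask, n_heads, torch.float32)
    # Per the docstring, the result is (batch_size * n_heads, 1, seq_len) == (32, 1, 8).
    print(alibi.shape)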