auto_tp_model_utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist
import torch
from typing import Optional
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409 - the ALiBi tensor is not causal as the original paper mentions; it
    relies on a translation invariance of softmax for a quick implementation: with `l` a tensor and `a` a fixed value,
    `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`, *required*):
            Number of attention heads.
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            Dtype of the output tensor.

    Returns:
        Tensor shaped (batch_size * num_heads, 1, max_seq_len).
    """
    import math
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2**math.floor(math.log2(num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        device=attention_mask.device,
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
                                  device=attention_mask.device,
                                  dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    # Note: alibi will be added to the attention bias that is applied to the query-key product of attention
    # => therefore alibi has to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcast correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    if dist.is_initialized():
        num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size())
        offset = sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()])
        alibi = alibi.view(batch_size, num_heads, 1, seq_length)
        alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
        return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
    else:
        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
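

# --- Hedged example (not part of the original module) ------------------------
# A minimal sketch showing the shape contract of build_bloom_alibi_tensor in a
# single-process run (dist not initialized). The helper name and the concrete
# sizes below are illustrative assumptions, not DeepSpeed API.
def _example_bloom_alibi_shapes():
    batch_size, seq_length, num_heads = 2, 5, 16
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.int64)
    alibi = build_bloom_alibi_tensor(attention_mask, num_heads, torch.float32)
    # Without a process group the full tensor is returned:
    # (batch_size * num_heads, 1, seq_length) == (32, 1, 5)
    assert alibi.shape == (batch_size * num_heads, 1, seq_length)
    return alibi
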
def get_alibi_mask(self, tensor, seq_length_with_past):
    mask = self.get_alibi_mask_orig(tensor, seq_length_with_past)
    if not self.training and dist.is_initialized():
        num_heads_per_rank = get_shard_size(self.n_head, dist.get_world_size())
        offset = sum(get_shard_size_list(self.n_head, dist.get_world_size())[0:dist.get_rank()])
        mask = mask[offset:num_heads_per_rank + offset, :seq_length_with_past, :seq_length_with_past]

    return mask
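

# --- Hedged example (not part of the original module) ------------------------
# A minimal sketch of the head-sharding arithmetic used above, written with
# plain integers instead of get_shard_size/get_shard_size_list so it runs
# without an initialized process group. It assumes the default behaviour of an
# as-even-as-possible split of num_heads across world_size ranks.
def _example_head_shard(num_heads: int, world_size: int, rank: int):
    shard_sizes = [num_heads // world_size + (1 if r < num_heads % world_size else 0)
                   for r in range(world_size)]
    num_heads_per_rank = shard_sizes[rank]  # analogous to get_shard_size(...)
    offset = sum(shard_sizes[:rank])  # analogous to sum(get_shard_size_list(...)[0:rank])
    # e.g. num_heads=32, world_size=4, rank=1 -> (8, 8): this rank owns heads 8..15
    return num_heads_per_rank, offset
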
def build_mpt_atten_bias_tensor(self,
                                device,
                                dtype,
                                attention_mask: Optional[torch.ByteTensor] = None,
                                prefix_mask: Optional[torch.ByteTensor] = None,
                                sequence_id: Optional[torch.LongTensor] = None):
    (attn_bias, attention_mask) = self._attn_bias_orig(device,
                                                       dtype,
                                                       attention_mask=attention_mask,
                                                       prefix_mask=prefix_mask,
                                                       sequence_id=sequence_id)
    if dist.is_initialized():
        num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size())
        offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()])
        attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
    return attn_bias, attention_mask
def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None) -> torch.Tensor:
    r"""
    Link to paper: https://arxiv.org/abs/2108.12409 - the ALiBi tensor is not causal as the original paper mentions; it
    relies on a translation invariance of softmax for a quick implementation. This implementation was copied from the
    ALiBi implementation in the MPT source code, which yields slightly different results than the Bloom alibi:
    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
    """
    alibi = self.build_mpt_alibi_tensor_orig(num_heads, sequence_length, alibi_bias_max, device)
    if dist.is_initialized():
        num_heads_per_rank = int(num_heads / dist.get_world_size())
        offset = dist.get_rank() * num_heads_per_rank
        alibi = alibi[offset:num_heads_per_rank + offset, :, :]
    return alibi
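

# --- Hedged example (not part of the original module) ------------------------
# A minimal sketch of how a wrapper such as get_alibi_mask is typically wired
# in: the model's original method is stashed under an *_orig name and the
# sharding-aware version is bound in its place. `model` and the attribute names
# checked here are hypothetical; this is not the DeepSpeed injection API.
def _example_patch_get_alibi_mask(model):
    import types
    if hasattr(model, "get_alibi_mask") and not hasattr(model, "get_alibi_mask_orig"):
        model.get_alibi_mask_orig = model.get_alibi_mask
        model.get_alibi_mask = types.MethodType(get_alibi_mask, model)
    return model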