attention_base.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from typing import Any, Dict, Optional, Tuple, Type

import torch

from ...ragged import RaggedBatchWrapper
from deepspeed.runtime.config_utils import DeepSpeedConfigModel

from ..ds_module import DSModuleBase
from ..module_registry import DSModuleRegistryBase
from ..configs import DSSelfAttentionConfig


class DSSelfAttentionBase(DSModuleBase):
    """
    Base mixin for all attention modules. The interface represented by this module
    is broadly:

    output = attention(query_key_value,
                       Optional[kv_cache],
                       Optional[attention_mask],
                       Optional[attention_bias])
    """

    @staticmethod
    def config_class() -> Type[DeepSpeedConfigModel]:
        return DSSelfAttentionConfig

    def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[str, Any]) -> None:
        super().__init__(config, implementation_config)
    @property
    def kv_block_size(self) -> int:
        """
        Return the preferred granularity for a blocked KV-cache implementation.
        """
        raise NotImplementedError()
    @property
    def q_block_size(self) -> int:
        """
        Property to calculate blocking granularity for the query dimension.
        This has no impact on the KV-cache structure, but will affect the
        number of attention atoms associated with a batch.
        """
        raise NotImplementedError()

    def build_atoms(self, ragged_batch: RaggedBatchWrapper) -> None:
        """
        Build the atoms for this module. This is not a strict requirement for the class,
        so this method is a no-op by default rather than abstract.
        """
        pass
    def forward(self,
                q_k_v: torch.Tensor,
                kv_cache: torch.Tensor,
                batch: RaggedBatchWrapper,
                attention_mask: Optional[torch.Tensor] = None,
                attention_bias: Optional[torch.Tensor] = None,
                inv_freqs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters:
            q_k_v (torch.Tensor): Query, key, and value tensors. Expected shape is:
                [
                    batch,
                    seq_len,
                    2 * self._config.n_heads_kv + self._config.n_heads_q,
                    self._config.head_size
                ].
            kv_cache (Optional[torch.Tensor]): Key and value cache tensor. Expected shape is
                [
                    2,
                    batch,
                    kv_cache_len,
                    self._config.n_heads_kv,
                    self._config.head_size
                ]. If None, the cache is disabled. The `kv_cache_len` dimension does not need to
                be contiguous (it should expand the stride by `max_out_tokens`).
            batch (RaggedBatchWrapper): Ragged batch metadata.
            attention_mask (Optional[torch.Tensor]): Attention mask tensor. If None, masking is
                disabled. In the case of conflicting information this defers to the config: if the
                config implies causal attention, the mask is ignored.
            attention_bias (Optional[torch.Tensor]): Attention bias tensor. If None, bias is disabled.
        """
        raise NotImplementedError()
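
# --------------------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream interface): a minimal concrete subclass
# showing how the contract above is typically filled in. Assumptions are flagged inline:
# the config exposes `n_heads_q`, `n_heads_kv`, and `head_size` (they are referenced in
# the forward docstring), the fused heads in `q_k_v` are laid out as [q, k, v], and the
# KV cache and ragged metadata are ignored so the example can lean on PyTorch's
# scaled_dot_product_attention. A real implementation would also satisfy any additional
# hooks DSModuleBase requires (e.g. a registered module name), which are omitted here.
class _NaiveDenseSelfAttention(DSSelfAttentionBase):

    @property
    def kv_block_size(self) -> int:
        # Hypothetical granularity; a real kernel returns whatever block size its
        # blocked KV-cache layout requires.
        return 128

    @property
    def q_block_size(self) -> int:
        return 128

    def forward(self,
                q_k_v: torch.Tensor,
                kv_cache: torch.Tensor,
                batch: RaggedBatchWrapper,
                attention_mask: Optional[torch.Tensor] = None,
                attention_bias: Optional[torch.Tensor] = None,
                inv_freqs: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        n_q = self._config.n_heads_q
        n_kv = self._config.n_heads_kv

        # Split the fused projection along the head dimension:
        # [batch, seq_len, n_heads_q + 2 * n_heads_kv, head_size].
        q = q_k_v[:, :, :n_q]
        k = q_k_v[:, :, n_q:n_q + n_kv]
        v = q_k_v[:, :, n_q + n_kv:]

        # Grouped-query attention: replicate KV heads to match the query head count.
        if n_kv != n_q:
            k = k.repeat_interleave(n_q // n_kv, dim=2)
            v = v.repeat_interleave(n_q // n_kv, dim=2)

        # SDPA expects [batch, heads, seq, head_size]; causal masking mirrors the
        # "defer to the config" behavior described in the forward docstring.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out.transpose(1, 2), kv_cache
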
class DSSelfAttentionRegistry(DSModuleRegistryBase):
    registry: Dict = {}

    @staticmethod
    def associated_class() -> Type[DSModuleBase]:
        return DSSelfAttentionBase
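
# --------------------------------------------------------------------------------------
# Usage sketch (assumption): concrete implementations are normally attached to the
# registry by a registration helper on DSModuleRegistryBase, which is not shown in this
# file; inserting into the visible `registry` dict by hand, under a hypothetical name,
# mirrors that effect for illustration only.
DSSelfAttentionRegistry.registry["naive_dense_attention"] = _NaiveDenseSelfAttention

# A lookup by name then yields the implementation class, whose configuration type is
# tied back to DSSelfAttentionConfig through config_class().
_impl = DSSelfAttentionRegistry.registry["naive_dense_attention"]
assert _impl.config_class() is DSSelfAttentionConfig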