engine_v2.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import json
import pickle
from typing import Iterable, Tuple

import torch

import deepspeed.comm as dist

from deepspeed.accelerator import get_accelerator
from deepspeed.comm.comm import init_distributed

from .model_implementations import InferenceV2Policy
from .logging import inference_logger
from .ragged import DSStateManager, RaggedBatchWrapper, PlaceholderSequenceDescriptor
from .scheduling_utils import SchedulingError, SchedulingResult
from .model_implementations.flat_model_helpers import make_param_filename, make_metadata_filename
from .model_implementations.inference_model_base import DSInferenceModelBase

from .config_v2 import RaggedInferenceEngineConfig

INFERENCE_MODEL_TIMER = "model-forward-inference"


class InferenceEngineV2:

    _config: RaggedInferenceEngineConfig
    """
    Configuration of the inference engine.
    """

    _model: DSInferenceModelBase
    """
    Inference model supporting ragged inference.
    """

    _state_manager: DSStateManager
    """
    Persistent state manager for sequences and KV-cache.
    """

    @property
    def free_blocks(self) -> torch.Tensor:
        """
        Number of free KV blocks. This is a tensor of shape [n_kv_cache_groups] where each
        element is the number of free blocks in the corresponding KV cache group.
        """
        return self._state_manager.free_blocks

    @property
    def n_kv_cache_groups(self) -> int:
        """
        Number of KV cache groups.
        """
        return self._state_manager.n_kv_cache_groups

    def model(self) -> DSInferenceModelBase:
        """
        The model implementation.
        """
        return self._model

    def __init__(self, policy: InferenceV2Policy, engine_config: RaggedInferenceEngineConfig) -> None:
        """
        Create the Inference V2 engine.

        Arguments:
            policy (InferenceV2Policy): Policy for the model implementation. This policy object
                will be used to build the model and load the checkpoint associated with it.
            engine_config (RaggedInferenceEngineConfig): Configuration for the inference engine.
        """
        self._config = engine_config
        self._policy = policy
        self._base_mp_group = self._initialize_tp_group()

        # Build model from policy
        inference_logger().info("Building model...")
        self._model = self._policy.build_model(self._config, self._base_mp_group)
        inference_logger().info("Model built.")

        # Create state manager
        self._batch = RaggedBatchWrapper(self._config.state_manager)
        self._state_manager = DSStateManager(self._config.state_manager,
                                             self._model.kv_cache_config(),
                                             base_mp_group=self._base_mp_group)
        self._model.set_state_manager(self._state_manager)

    def _initialize_tp_group(self):
        """
        Implementation of our TP group initialization.
        """
        init_distributed()
        local_rank = int(os.getenv("LOCAL_RANK", 0))
        get_accelerator().set_device(local_rank)

        if local_rank >= self._config.tensor_parallel.tp_size:
            raise RuntimeError("Local rank is greater than TP size, ensure that the TP config is correct.")

        ranks = list(range(self._config.tensor_parallel.tp_size))
        return dist.new_group(ranks=ranks)

    def put(self,
            batch_uids: Iterable[int],
            batch_tokens: Iterable[torch.Tensor],
            do_checks: bool = True) -> torch.Tensor:
        """
        Put a ragged batch onto the inference engine. This will perform one forward and return
        a Tensor of the shape [len(batch_uids), *output_shape]. Logits for the non-final tokens
        are not calculated.

        Arguments:
            batch_uids: Iterable of uids for the batch on the host
            batch_tokens: Iterable of token tensors for the batch on the host
            do_checks: Check schedulability when it is set to True. You can skip this check
                for better performance when it has already been completed.
        """
        if do_checks:
            token_lens = [len(tokens) for tokens in batch_tokens]
            schedule_check = self.can_schedule(batch_uids, token_lens)
            if schedule_check != SchedulingResult.Success:
                raise SchedulingError(schedule_check)

        self._batch.clear()
        for uid, tokens in zip(batch_uids, batch_tokens):
            host_seq_desc = self._state_manager.get_or_create_sequence(uid)
            self._model.maybe_allocate_kv(host_seq_desc, tokens.numel())
            host_seq_desc.pre_forward(tokens.numel())

            # We can disable checks since we already validated schedulability.
            self._batch.insert_sequence(host_seq_desc, tokens, do_checks=do_checks)

        # Send all metadata to the device
        self._batch.finalize()

        # Prep all data structures for the actual forward (in anticipation of CG in the future)
        # and also to amortize some of the costs in a more straightforward way.
        self._model.prepare_batch(self._batch)

        # Model implementation will pick up in the forward.
        logits = self._model.forward(self._batch)

        # We return one set of logits per sequence in the batch (saves cost on unembedding)
        assert logits.shape[0] == self._batch.current_sequences

        for uid in batch_uids:
            host_seq_desc = self._state_manager.get_sequence(uid)
            host_seq_desc.post_forward()  # Updates sequence metadata.
            self._model.maybe_free_kv(host_seq_desc)

        return logits
    def query(self, uid: int, max_request_tokens: int, max_request_blocks) -> Tuple[int, torch.Tensor]:
        """
        Determine the number of tokens and KV blocks to reserve for a given request. Given a UID
        (this UID may not be recognized by the model yet), this will return the number of tokens
        and blocks to reserve for the request.

        Arguments:
            uid (int): The UID of the sequence (as tracked by the scheduling entity). If
                this is a new sequence (with a UID unknown to the inference engine), then
                an empty placeholder is created to pass to the occupancy logic.
            max_request_tokens (int): The maximum number of tokens to hypothetically send.
            max_request_blocks: The maximum number of KV blocks available to the request.

        Returns:
            Tuple[int, torch.Tensor]: The number of tokens that can be scheduled for the
                request and the number of KV blocks required to schedule them.
        """
        seq_desc = self._state_manager.get_sequence(uid)
        if seq_desc is None:
            if (self._state_manager.n_tracked_sequences == self._config.state_manager.max_tracked_sequences):
                return (0, 0)
            seq_desc = PlaceholderSequenceDescriptor()

        req_tokens, req_blocks = self._model.get_kv_requirements(seq_desc, max_request_tokens, max_request_blocks)

        return (req_tokens, req_blocks)
    def can_schedule(self, uids: Iterable[int], lengths: Iterable[int]) -> SchedulingResult:
        """
        Dry run a batch to determine if it can be scheduled. Placeholder sequences will be
        created for any UIDs that are unknown to the inference engine.

        Arguments:
            uids (Iterable[int]): Iterable of UIDs for the batch.
            lengths (Iterable[int]): Iterable of lengths for each sequence of the batch. These
                lengths correspond to the number of tokens to send in the hypothetical forward;
                history tokens will be determined via UID lookup and future tokens are disregarded.

        Returns:
            SchedulingResult: SchedulingResult.Success if the batch can be scheduled, otherwise
                the result indicating which limit would be exceeded.
        """
        cur_seqs = self._state_manager.n_tracked_sequences
        free_blocks = self._state_manager.free_blocks
        req_blocks = 0
        batch_len = 0

        if len(uids) > self._config.state_manager.max_ragged_sequence_count:
            # Can only compose a batch from a limited number of sequences
            return SchedulingResult.BatchSequenceLimitExceeded

        for uid, length in zip(uids, lengths):
            seq_desc = self._state_manager.get_sequence(uid)
            if seq_desc is None:
                cur_seqs += 1
                seq_desc = PlaceholderSequenceDescriptor()

            sched_len, sched_blocks = self._model.get_kv_requirements(seq_desc, length, free_blocks)

            if sched_len != length:
                # We ran out of KV cache
                return SchedulingResult.KVCacheLimitExceeded

            batch_len += length
            free_blocks -= sched_blocks

        if cur_seqs > self._config.state_manager.max_tracked_sequences:
            # Would run out of tracking metadata
            return SchedulingResult.EngineSequenceLimitExceeded

        if batch_len > self._config.state_manager.max_ragged_batch_size:
            # Would exceed the maximum batch size
            return SchedulingResult.BatchTokenLimitExceeded

        return SchedulingResult.Success
    def get_remaining_block_capacity(self, uid: int) -> int:
        """
        Get the remaining capacity of the last block already allocated.
        """
        seq_desc = self._state_manager.get_sequence(uid)
        if seq_desc is None:
            return 0
        return self._model.get_remaining_block_capacity(seq_desc)

    def flush(self, uid: int) -> None:
        """
        Remove all state associated with a sequence from the inference engine.

        Arguments:
            uid (int): The UID of the sequence to flush.
        """
        self._state_manager.flush_sequence(uid)
    def serialize(self, save_path: str) -> None:
        """
        Serialize the model to disk.

        Arguments:
            save_path (str): Path to the directory to serialize the model into.
        """
        param_file_name = make_param_filename(save_path, self._model.tp_rank, self._model.tp_size)
        metadata_file_name = make_metadata_filename(save_path, self._model.tp_rank, self._model.tp_size)

        # Save the flattened parameters and their metadata.
        torch.save(self._model.flattened_params, param_file_name)

        with open(metadata_file_name, "w") as metadata_file:
            json.dump(self._model.flattened_param_metadata.json(), metadata_file)

        # Only rank 0 needs to persist the model config.
        if self._model.tp_rank == 0:
            with open(os.path.join(save_path, "ds_model_config.pkl"), "wb") as config_file:
                pickle.dump(self._model._config, config_file)
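

# --------------------------------------------------------------------------------------
# Illustrative usage sketch, not part of the engine implementation. It assumes the
# caller has already built an `InferenceEngineV2` from an `InferenceV2Policy` and
# tokenized a prompt; `uid`, `prompt_tokens`, and the greedy sampling below are
# placeholders chosen for this example rather than part of the DeepSpeed API.
def _example_greedy_decode(engine: InferenceEngineV2,
                           uid: int,
                           prompt_tokens: torch.Tensor,
                           max_new_tokens: int = 32) -> list:
    """
    Minimal sketch of the prefill/decode loop against the ragged engine: check
    schedulability, run the prompt through `put`, then feed one token back per step,
    and finally `flush` the sequence to release its KV blocks.
    """
    # Prefill: verify the prompt fits before scheduling it.
    schedule_check = engine.can_schedule([uid], [prompt_tokens.numel()])
    if schedule_check != SchedulingResult.Success:
        raise SchedulingError(schedule_check)
    logits = engine.put([uid], [prompt_tokens], do_checks=False)

    generated = []
    for _ in range(max_new_tokens):
        # One set of logits is returned per sequence; greedily pick the next token.
        next_token = logits[0].argmax(dim=-1).reshape(1).to(dtype=torch.int64, device="cpu")
        generated.append(int(next_token.item()))

        # Decode step: send just the new token; history is looked up by UID.
        logits = engine.put([uid], [next_token])

    # Release KV blocks and tracking metadata for this sequence.
    engine.flush(uid)
    return generated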