# engine.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. import time
  6. import os
  7. from deepspeed import comm as dist
  8. from deepspeed.utils.logging import log_dist
  9. from torch.nn.modules import Module
  10. from packaging import version as pkg_version
  11. from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
  12. from deepspeed.utils.timer import SynchronizedWallClockTimer
  13. from ..runtime.state_dict_factory import SDLoaderFactory
  14. from ..runtime.weight_quantizer import WeightQuantization
  15. from ..module_inject import replace_transformer_layer, generic_injection
  16. from ..comm.comm import init_distributed
  17. from ..pipe import PipelineModule
  18. from ..moe.utils import has_moe_layers
  19. from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
  20. from deepspeed.accelerator import get_accelerator
  21. from ..module_inject.policy import TransformerPolicy
  22. from ..module_inject.auto_tp import AutoTP
  23. from ..module_inject.replace_policy import generic_policies
  24. from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor
  25. from ..ops.transformer.inference.ds_attention import DeepSpeedSelfAttention
  26. from ..model_implementations.transformers.ds_transformer import DeepSpeedTransformerInference
  27. DS_INFERENCE_ENABLED = False
  28. from torch import nn
  29. INFERENCE_MODEL_TIMER = "model-forward-inference"
  30. class InferenceEngine(Module):
  31. inference_mp_group = None
  32. inference_ep_group = None
  33. expert_mp_group = None
  34. def __init__(self, model, config):
  35. """
  36. Args:
  37. model: torch.nn.Module
  38. config: DeepSpeedInferenceConfig
  39. """
  40. global DS_INFERENCE_ENABLED
  41. DS_INFERENCE_ENABLED = True
  42. super().__init__()
  43. # Have to import here because inference_module is a global, but python
  44. # globals only work at the module level and will not be updated unless
  45. # we import it each time we init a new inference engine.
  46. from ..model_implementations.transformers.ds_transformer import inference_module
  47. if inference_module is not None:
  48. self.destroy()
  49. self.module = model
  50. self._config = config
  51. self._get_model_config_generate(config) # keep for weird backward compatibility
  52. # patch model generate with ours if model uses it
  53. if hasattr(self.module, "generate"):
  54. self.generate = self._generate
  55. if hasattr(self.module, "config"):
  56. TransformerPolicy.hf_model_config = self.module.config
  57. # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
  58. # todo: this will get changed when Molly's PR on auto injection dict is merged
  59. self.injection_dict = config.injection_policy
  60. # todo: refactor the mp_group and mp_size related in the next refactor
  61. self.mp_group = config.tensor_parallel.tp_group
  62. self.mpu = config.tensor_parallel.mpu
  63. #self._validate_args(self.mpu, config.replace_with_kernel_inject)
  64. self.quantize_merge_count = 1
  65. self.quantization_scales = None
  66. # these are not needed in the config as we are creating them ourselves in the inference engine
  67. self.ep_group = None # config.moe.ep_group
  68. self.expert_mp_group = None # config.moe.ep_mp_group
  69. self.cuda_graph_created = False
  70. self.checkpoint_engine = TorchCheckpointEngine()
  71. quantization_setting = None
  72. self._init_quantization_setting(
  73. quantization_setting) # todo: update with the new quant config for weight quant
  74. self.model_profile_enabled = False
  75. self._model_times = []
  76. if not self.injection_dict and config.replace_with_kernel_inject:
  77. # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
  78. self.remove_mask_prepare_for_bloom()
  79. if self.injection_dict or not config.replace_with_kernel_inject:
  80. # This is a hack to redefine the alibi func due to TP
  81. if config.tensor_parallel.tp_size > 1:
  82. self.build_alibi_tensor()
  83. self.build_attn_bias()
  84. if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
  85. assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
  86. "If you want to use cuda graph, please upgrade torch to at least v1.10"
  87. # Check if model passed to engine is loaded w/ meta tensors, in which case
  88. # kernel injection must be enabled.
  89. # NOTE: This check assumes a Hugging Face hierarchy for the device type i.e. module.device.type
  90. self.model_meta_device = self.module.device.type == 'meta' if hasattr(self.module, "device") else False
  91. # convert model to intended dtype
  92. if config.dtype:
  93. self._convert_to_dtype(config)
  94. if self.mpu:
  95. config.tensor_parallel.tp_size = dist.get_world_size(group=self.mpu.get_model_parallel_group())
  96. self.mp_group = self.mpu.get_model_parallel_group()
  97. elif config.tensor_parallel.tp_size > 1:
  98. self._create_model_parallel_group(config)
  99. config.tensor_parallel.tp_group = self.mp_group
  100. if isinstance(self.module, torch.nn.Module):
  101. moe, _ = has_moe_layers(self.module)
  102. else:
  103. moe = False
  104. if moe and dist.get_world_size() > 1:
  105. self._create_ep_parallel_group(config.moe.moe_experts)
  106. # We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism if tp_size > 1.
  107. if self.injection_dict:
  108. # 1. User specified Tensor Parallelism
  109. assert not config.replace_with_kernel_inject, "Cannot use both user specified injection policy and kernel injection"
  110. for client_module, injection_policy in self.injection_dict.items():
  111. assert issubclass(client_module,
  112. torch.nn.Module), f"{client_module} is not a subclass of torch.nn.Module"
  113. # construct the tuple and pass that instead of a string or dict.
  114. if isinstance(injection_policy, str):
  115. config.injection_policy_tuple = (injection_policy, )
  116. else:
  117. config.injection_policy_tuple = injection_policy
  118. layer_names = [name for name, _ in self.module.named_modules()]
  119. for policy in config.injection_policy_tuple:
  120. if not any(name.endswith(policy) for name in layer_names):
  121. raise ValueError(f"Injection policy layer'{policy}' not valid.")
  122. self._apply_injection_policy(config, client_module)
  123. else:
  124. if config.replace_with_kernel_inject:
  125. # 2. DeepSpeed Kernel Injection
  126. self._apply_injection_policy(config)
  127. elif config.tensor_parallel.tp_size > 1:
  128. # 3. Automatic Tensor Parallelism
  129. parser_dict = AutoTP.tp_parser(model)
  130. print("AutoTP: ", parser_dict)
  131. for client_module, injection_policy in parser_dict:
  132. if isinstance(injection_policy, str):
  133. config.injection_policy_tuple = (injection_policy, )
  134. else:
  135. config.injection_policy_tuple = injection_policy
  136. self._apply_injection_policy(config, client_module)
  137. device = get_accelerator().current_device_name()
  138. self.module.to(device)
  139. if config.tensor_parallel.tp_size > 1:
  140. _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
  141. dist.broadcast(_rng_state, 0)
  142. get_accelerator().set_rng_state(_rng_state.cpu())
  143. if config.tensor_parallel.tp_size > 1:
  144. assert not config.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
  145. # Check if local CUDA graphs can be created in replacement modules
  146. self.local_cuda_graph = self._local_cuda_graph_used(self.module)
  147. def destroy(self):
  148. # Have to import here because inference_module is a global, but python
  149. # globals only work at the module level and will not be updated unless
  150. # we import it each time we init a new inference engine.
  151. from ..model_implementations.transformers.ds_transformer import inference_module
  152. DeepSpeedTransformerInference.layer_id = 0
  153. DeepSpeedSelfAttention.num_layers = 0
  154. if inference_module is not None:
  155. inference_module.release_workspace()
  156. inference_module = None
  157. def profile_model_time(self, use_cuda_events=True):
  158. if not self.model_profile_enabled and not self._config.enable_cuda_graph:
  159. self.module.register_forward_pre_hook(self._pre_forward_hook)
  160. self.module.register_forward_hook(self._post_forward_hook)
  161. self.model_profile_enabled = True
  162. self.use_cuda_events = use_cuda_events
  163. if self.use_cuda_events:
  164. self.timers = SynchronizedWallClockTimer()
  165. # todo: remove this once all the config dicts are centralized from top level pydantic config
  166. def _get_model_config_generate(self, config):
  167. # this is being passed to replace_transformer_layer(config=self.user_model_config_dict)
  168. self.config = getattr(self.module, 'config', None) if config.config is None else config.config
  169. def remove_mask_prepare_for_bloom(self):
  170. if hasattr(self.module, 'transformer'):
  171. if hasattr(self.module.transformer, '_prepare_attn_mask'):
  172. self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
  173. def build_alibi_tensor(self):
  174. if hasattr(self.module, 'transformer'):
  175. if hasattr(self.module.transformer, 'build_alibi_tensor'):
  176. self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
  177. if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'):
  178. self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor
  179. self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor
  180. def build_attn_bias(self):
  181. if hasattr(self.module, 'transformer'):
  182. if hasattr(self.module.transformer, '_attn_bias'):
  183. self.module.transformer._attn_bias_orig = self.module.transformer._attn_bias
  184. self.module.transformer.__class__._attn_bias = build_mpt_atten_bias_tensor
  185. def _pre_forward_hook(self, module, *inputs, **kwargs):
  186. if self.use_cuda_events:
  187. self.timers(INFERENCE_MODEL_TIMER).start()
  188. else:
  189. get_accelerator().synchronize()
  190. self._start = time.time()
  191. def _post_forward_hook(self, module, input, output):
  192. if self.use_cuda_events:
  193. self.timers(INFERENCE_MODEL_TIMER).stop()
  194. elapsed_time = self.timers(INFERENCE_MODEL_TIMER).elapsed(reset=True)
  195. else:
  196. get_accelerator().synchronize()
  197. self._end = time.time()
  198. elapsed_time = (self._end - self._start) * 1e3 # convert seconds to ms
  199. self._model_times.append(elapsed_time)
  200. def _create_model_parallel_group(self, config):
  201. # Call the init process
  202. if InferenceEngine.inference_mp_group is None:
  203. init_distributed()
  204. local_rank = int(os.getenv('LOCAL_RANK', '0'))
  205. get_accelerator().set_device(local_rank)
  206. ranks = [i for i in range(config.tensor_parallel.tp_size)]
  207. self.mp_group = dist.new_group(ranks)
  208. InferenceEngine.inference_mp_group = self.mp_group
  209. else:
  210. self.mp_group = InferenceEngine.inference_mp_group
  211. def _create_ep_parallel_group(self, moe_experts):
  212. # Call the init process
  213. self.ep_group = {}
  214. self.expert_mp_group = {}
  215. moe_experts = moe_experts if type(moe_experts) is list else [moe_experts]
  216. for e in moe_experts:
  217. self.ep_group.update({e: None})
  218. self.expert_mp_group.update({e: None})
  219. for moe_ep_size in self.ep_group.keys():
  220. num_ep_groups = dist.get_world_size() // moe_ep_size
  221. for i in range(num_ep_groups):
  222. ep_cnt = i * moe_ep_size
  223. size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size
  224. ranks = list(range(ep_cnt, ep_cnt + size))
  225. _ep_group = dist.new_group(ranks)
  226. if dist.get_rank() in ranks:
  227. self.ep_group.update({moe_ep_size: _ep_group})
  228. if dist.get_world_size() > moe_ep_size:
  229. num_expert_mp_groups = dist.get_world_size() // num_ep_groups
  230. expert_mp_size = dist.get_world_size() // moe_ep_size
  231. for i in range(num_expert_mp_groups):
  232. expert_mp_comm_ranks = [i + nr * moe_ep_size for nr in range(expert_mp_size)]
  233. _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
  234. if dist.get_rank() in expert_mp_comm_ranks:
  235. self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
  236. def _init_quantization_setting(self, quantization_setting):
  237. self.quantize_bits = 8
  238. self.mlp_extra_grouping = False
  239. self.quantize_groups = 1
  240. if type(quantization_setting) is tuple:
  241. self.mlp_extra_grouping, \
  242. self.quantize_groups = quantization_setting
  243. elif quantization_setting is not None:
  244. self.quantize_groups = quantization_setting
  245. log_dist(
  246. f"quantize_bits = {self.quantize_bits} "
  247. f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
  248. f"quantize_groups = {self.quantize_groups}", [0])
  249. # TODO: remove this function and add this functionality to pydantic config checking
  250. def _validate_args(self, mpu, replace_with_kernel_inject):
  251. # TODO: to support SD pipeline we need to avoid this check for now
  252. if replace_with_kernel_inject and not isinstance(self.module, Module):
  253. raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
  254. if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
  255. raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")
  256. if mpu:
  257. methods = ["get_model_parallel_group", "get_data_parallel_group"]
  258. for method in methods:
  259. if not hasattr(mpu, method):
  260. raise ValueError(f"mpu is missing {method}")
  261. if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
  262. raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")
  263. supported_dtypes = [None, torch.half, torch.int8, torch.float]
  264. if self._config.dtype not in supported_dtypes:
  265. raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
  266. if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
  267. raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")
  268. def load_model_with_checkpoint(self, r_module):
  269. self.mp_replace = ReplaceWithTensorSlicing(
  270. mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
  271. error_msgs = []
  272. def load(module, state_dict, prefix):
  273. args = (state_dict, prefix, {}, True, [], [], error_msgs)
  274. if hasattr(module, 'weight'):
  275. if module.weight.data.is_meta:
  276. # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
  277. module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data,
  278. device="cpu"),
  279. requires_grad=module.weight.data.requires_grad)
  280. if 'query_key_value' in prefix:
  281. module.weight = self.mp_replace.strided_copy(module.weight.data,
  282. state_dict[prefix + 'weight'],
  283. num_splits=3)
  284. else:
  285. module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
  286. else:
  287. if module.norm.weight.data.is_meta:
  288. # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
  289. module.norm.weight = torch.nn.parameter.Parameter(
  290. data=torch.empty_like(module.norm.weight.data, device="cpu"),
  291. requires_grad=module.norm.weight.data.requires_grad)
  292. module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
  293. if prefix + 'bias' in self.key_list:
  294. if hasattr(module, 'norm'):
  295. if module.norm.bias.data.is_meta:
  296. # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
  297. module.norm.bias = torch.nn.parameter.Parameter(
  298. data=torch.empty_like(module.norm.bias.data, device="cpu"),
  299. requires_grad=module.norm.bias.data.requires_grad)
  300. module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
  301. else:
  302. if module.bias.data.is_meta:
  303. # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
  304. module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data,
  305. device="cpu"),
  306. requires_grad=module.bias.data.requires_grad)
  307. data = state_dict[prefix + 'bias']
  308. data = data.to(get_accelerator().current_device_name())
  309. module.bias = self.mp_replace.copy(module.bias, data)
  310. layer_policies = {
  311. nn.Linear: load,
  312. nn.Embedding: load,
  313. nn.LayerNorm: load,
  314. LinearLayer: load,
  315. LinearAllreduce: load
  316. }
  317. def load_module_recursive(module, prefix='', level=0):
  318. for name, child in module.named_children():
  319. if child.__class__ in layer_policies:
  320. checking_key = prefix + name + '.'
  321. if not any(checking_key in item for item in self.key_list):
  322. continue
  323. if len(list(child.parameters())) > 0 and list(child.parameters())[0].numel() == 0:
  324. if len(child.weight.ds_shape) == 1:
  325. child = Normalize(dim=child.weight.ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
  326. setattr(module, name, child)
  327. load(child, self.sd, prefix + name + '.')
  328. else:
  329. load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1)
  330. load_module_recursive(r_module)
  331. embedding_weight = None
  332. for n, p in r_module.named_parameters():
  333. if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
  334. embedding_weight = p
  335. if embedding_weight is not None and hasattr(r_module, "lm_head") and hasattr(
  336. r_module.lm_head, "weight") and r_module.lm_head.weight.is_meta:
  337. r_module.lm_head.weight = embedding_weight
  338. def _apply_injection_policy(self, config, client_module=None):
  339. # client_module is only passed when using the injection_dict method.
  340. checkpoint_dir = config.checkpoint
  341. checkpoint = SDLoaderFactory.get_sd_loader_json(checkpoint_dir,
  342. self.checkpoint_engine) if checkpoint_dir is not None else None
  343. generic_injection(self.module, dtype=config.dtype, enable_cuda_graph=config.enable_cuda_graph)
  344. if isinstance(self.module, torch.nn.Module):
  345. # config is our DeepSpeedInferenceConfig and self.config is the HF model config
  346. replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
  347. def _get_all_ckpt_names(self, checkpoints_path, tag):
  348. ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
  349. import glob
  350. ckpt_files = glob.glob(ckpt_file_pattern)
  351. ckpt_files.sort()
  352. return ckpt_files
  353. def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
  354. if mp_placeholder is not None:
  355. mp_rank_str = mp_placeholder
  356. else:
  357. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  358. mp_rank_str = "{:02d}".format(mp_rank)
  359. ckpt_name = os.path.join(
  360. checkpoints_path,
  361. "mp_rank_" + mp_rank_str + "_model_states.pt",
  362. )
  363. return ckpt_name
  364. def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
  365. is_pipe_parallel = isinstance(self.module, PipelineModule)
  366. if is_pipe_parallel:
  367. raise RuntimeError('pipeline parallelism is currently not supported in inference.')
  368. if not isinstance(load_dir, dict) and os.path.isdir(load_dir):
  369. if tag is None:
  370. latest_path = os.path.join(load_dir, "latest")
  371. if os.path.isfile(latest_path):
  372. with open(latest_path, "r") as fd:
  373. tag = fd.read().strip()
  374. ckpt_list = self._get_all_ckpt_names(load_dir, tag)
  375. sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
  376. else:
  377. sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)
  378. checkpoint = sd_loader['checkpoints']
  379. if type(checkpoint) is list:
  380. self.sd = torch.load(checkpoint[0], map_location='cpu')
  381. self.key_list = list(self.sd.keys())
  382. self.load_model_with_checkpoint(self.module)
  383. for i in range(1, len(checkpoint)):
  384. if not dist.is_initialized() or dist.get_rank() == 0:
  385. print(f"loading checkpoint ({i})")
  386. self.sd = torch.load(checkpoint[i], map_location=get_accelerator().device_name())
  387. self.key_list = list(self.sd.keys())
  388. self.load_model_with_checkpoint(self.module)
  389. else:
  390. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  391. load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size,
  392. mp_rank,
  393. is_pipe_parallel=is_pipe_parallel,
  394. quantize=(self._config.dtype is torch.int8),
  395. quantize_groups=self.quantize_groups,
  396. mlp_extra_grouping=self.mlp_extra_grouping)
  397. self.quantization_scales, self.quantize_merge_count = quantize_config
  398. moe, _ = has_moe_layers(self.module)
  399. if moe:
  400. from deepspeed.runtime.engine import DeepSpeedEngine
  401. old_moe_load = False
  402. if not isinstance(checkpoint['num_experts'], list):
  403. old_moe_load = True
  404. DeepSpeedEngine.load_moe_state_dict(load_dir,
  405. tag,
  406. state_dict=checkpoint[self._choose_module_key(checkpoint)],
  407. old_moe_load=old_moe_load,
  408. model=self.module,
  409. mpu=self.mpu,
  410. checkpoint_engine=self.checkpoint_engine)
  411. self.module.load_state_dict(state_dict=checkpoint[self._choose_module_key(checkpoint)],
  412. strict=load_module_strict)
  413. def _choose_module_key(self, sd):
  414. assert not ('module' in sd
  415. and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
  416. assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' or 'module' keys, not sure how to proceed"
  417. if 'module' in sd:
  418. return 'module'
  419. elif 'model' in sd:
  420. return 'model'
  421. def _convert_to_dtype(self, config):
  422. if not isinstance(self.module, torch.nn.Module):
  423. return
  424. if False: #config.dtype is torch.int8 and self.quantization_scales is None:
  425. quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
  426. model, self.quantization_scales = quantizer.model_quantize(self.module, self.injection_dict,
  427. self.quantize_bits, self.quantize_groups)
  428. elif config.dtype == torch.half:
  429. self.module.half()
  430. elif config.dtype == torch.bfloat16:
  431. self.module.bfloat16()
  432. elif config.dtype == torch.float:
  433. self.module.float()
  434. def _create_cuda_graph(self, *inputs, **kwargs):
  435. # warmup to create the workspace and cublas handle
  436. cuda_stream = get_accelerator().Stream()
  437. cuda_stream.wait_stream(get_accelerator().current_stream())
  438. with get_accelerator().stream(cuda_stream):
  439. for i in range(3):
  440. ret = self.module(*inputs, **kwargs)
  441. get_accelerator().current_stream().wait_stream(cuda_stream)
  442. # create cuda_graph and assign static_inputs and static_outputs
  443. self._cuda_graphs = torch.cuda.CUDAGraph()
  444. self.static_inputs = inputs
  445. self.static_kwargs = kwargs
  446. with torch.cuda.graph(self._cuda_graphs):
  447. self.static_output = self.module(*self.static_inputs, **self.static_kwargs)
  448. self.cuda_graph_created = True
  449. def _graph_replay(self, *inputs, **kwargs):
  450. for i in range(len(inputs)):
  451. if torch.is_tensor(inputs[i]):
  452. self.static_inputs[i].copy_(inputs[i])
  453. for k in kwargs:
  454. if torch.is_tensor(kwargs[k]):
  455. self.static_kwargs[k].copy_(kwargs[k])
  456. self._cuda_graphs.replay()
  457. return self.static_output
  458. def model_times(self):
  459. assert self.model_profile_enabled, "model profiling is not enabled"
  460. model_times = self._model_times
  461. if self._config.enable_cuda_graph and len(self._model_times) == 0:
  462. raise ValueError("Model times are empty and cuda graph is enabled. If "
  463. "this is a GPT-style model this combo is not supported. If this is a "
  464. "BERT-style model this is a bug, please report it. "
  465. f"Model type is: {type(self.module)}")
  466. self._model_times = []
  467. return model_times
  468. def _module_match(self, module):
  469. for policy in generic_policies:
  470. policy = policy()
  471. if policy.match_replaced(module):
  472. return True
  473. return False
  474. def _local_cuda_graph_used(self, module):
  475. if isinstance(module, torch.nn.Module):
  476. return False
  477. else:
  478. sub_module_cuda_graph = False
  479. for name in module.__dict__.keys():
  480. sub_module = getattr(module, name)
  481. if self._module_match(sub_module) and hasattr(sub_module, "enable_cuda_graph"):
  482. sub_module_cuda_graph = True
  483. return sub_module_cuda_graph
  484. def forward(self, *inputs, **kwargs):
  485. """Execute forward propagation
  486. Arguments:
  487. *inputs: Variable length input list
  488. **kwargs: variable length keyword arguments
  489. """
  490. start = None
  491. if self.model_profile_enabled and get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph:
  492. get_accelerator().synchronize()
  493. start = time.time()
  494. if get_accelerator().device_name() == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
  495. if self.cuda_graph_created:
  496. outputs = self._graph_replay(*inputs, **kwargs)
  497. else:
  498. self._create_cuda_graph(*inputs, **kwargs)
  499. outputs = self._graph_replay(*inputs, **kwargs)
  500. else:
  501. outputs = self.module(*inputs, **kwargs)
  502. if self.model_profile_enabled and self._config.enable_cuda_graph:
  503. get_accelerator().synchronize()
  504. duration = (time.time() - start) * 1e3 # convert seconds to ms
  505. self._model_times.append(duration)
  506. return outputs
  507. def _generate(self, *inputs, **kwargs):
  508. # Reset KV-cache at the beginning of generate
  509. if hasattr(self.module, 'reset_cache'):
  510. self.module.reset_cache()
  511. num_beams = 1
  512. if "generation_config" in kwargs:
  513. gen_config = kwargs["generation_config"]
  514. num_beams = getattr(gen_config, "num_beams", 1)
  515. if "num_beams" in kwargs:
  516. num_beams = kwargs["num_beams"]
  517. if num_beams > 1:
  518. raise NotImplementedError("DeepSpeed does not support `num_beams` > 1, if this is important to you please "
  519. "add your request to: https://github.com/microsoft/DeepSpeed/issues/2506")
  520. if ("input_ids" in kwargs) and (kwargs["input_ids"].dim() == 2):
  521. for input_tensor in kwargs["input_ids"]:
  522. tensor_length = input_tensor.shape[-1]
  523. if tensor_length > self._config.max_out_tokens:
  524. raise RuntimeError(
  525. f"Input with size {tensor_length} exceeds maximum length of {self._config.max_out_tokens}. Please increase `max_tokens` in the DeepSpeed Inference Config."
  526. )
  527. return self.module.generate(*inputs, **kwargs)