engine.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. '''
  2. Copyright 2021 The Microsoft DeepSpeed Team
  3. '''
  4. import torch
  5. import time
  6. import os
  7. from deepspeed import comm as dist
  8. from deepspeed.utils.logging import log_dist
  9. from torch.nn.modules import Module
  10. from packaging import version as pkg_version
  11. from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
  12. from deepspeed.utils.timer import SynchronizedWallClockTimer
  13. from ..runtime.state_dict_factory import SDLoaderFactory
  14. from ..runtime.weight_quantizer import WeightQuantization
  15. from ..module_inject import replace_transformer_layer, generic_injection
  16. from ..comm.comm import init_distributed
  17. from ..pipe import PipelineModule
  18. from ..moe.utils import has_moe_layers
  19. from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
  20. from deepspeed.accelerator import get_accelerator
  21. from ..module_inject.policy import TransformerPolicy
  22. from ..module_inject.auto_tp import AutoTP
  23. from ..module_inject.replace_policy import generic_policies
  24. DS_INFERENCE_ENABLED = False
  25. from torch import nn
  26. INFERENCE_MODEL_TIMER = "model-forward-inference"
  27. class InferenceEngine(Module):
  28. inference_mp_group = None
  29. inference_ep_group = None
  30. expert_mp_group = None
  31. def __init__(self, model, config):
  32. """
  33. Args:
  34. model: torch.nn.Module
  35. config: DeepSpeedInferenceConfig
  36. """
  37. global DS_INFERENCE_ENABLED
  38. DS_INFERENCE_ENABLED = True
  39. super().__init__()
  40. self.module = model
  41. self._config = config
  42. self._get_model_config_generate(config) # keep for weird backward compatibility
  43. # patch model generate with ours if model uses it
  44. if hasattr(self.module, "generate"):
  45. self.generate = self._generate
  46. if hasattr(self.module, "config"):
  47. TransformerPolicy.hf_model_config = self.module.config
  48. # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
  49. # todo: this will get changed when Molly's PR on auto injection dict is merged
  50. self.injection_dict = config.injection_policy
  51. # todo: refactor the mp_group and mp_size related in the next refactor
  52. self.mp_group = config.tensor_parallel.tp_group
  53. self.mpu = config.tensor_parallel.mpu
  54. #self._validate_args(self.mpu, config.replace_with_kernel_inject)
  55. self.quantize_merge_count = 1
  56. self.quantization_scales = None
  57. # these are not needed in the config as we are creating them ourselves in the inference engine
  58. self.ep_group = None # config.moe.ep_group
  59. self.expert_mp_group = None # config.moe.ep_mp_group
  60. self.cuda_graph_created = False
  61. self.checkpoint_engine = TorchCheckpointEngine()
  62. quantization_setting = None
  63. self._init_quantization_setting(
  64. quantization_setting
  65. ) # todo: update with the new quant config for weight quant
  66. self.model_profile_enabled = False
  67. self._model_times = []
  68. # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
  69. self.remove_mask_prepare_for_bloom()
  70. if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
  71. assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
  72. "If you want to use cuda graph, please upgrade torch to at least v1.10"
  73. if config.checkpoint and not config.replace_with_kernel_inject:
  74. self._load_checkpoint(config.checkpoint)
  75. # convert model to intended dtype
  76. if config.dtype:
  77. self._convert_to_dtype(config)
  78. if self.mpu:
  79. config.tensor_parallel.tp_size = dist.get_world_size(
  80. group=self.mpu.get_model_parallel_group())
  81. self.mp_group = self.mpu.get_model_parallel_group()
  82. elif config.tensor_parallel.tp_size > 1:
  83. self._create_model_parallel_group(config)
  84. config.tensor_parallel.tp_group = self.mp_group
  85. if isinstance(self.module, torch.nn.Module):
  86. moe, _ = has_moe_layers(self.module)
  87. else:
  88. moe = False
  89. if moe and dist.get_world_size() > 1:
  90. self._create_ep_parallel_group(config.moe.moe_experts)
  91. # retain this from the old conditional argument being passed to apply_injection_policy()
  92. if not config.replace_with_kernel_inject:
  93. config.checkpoint = None
  94. # We only support three modes: 1) user specified policy for tensor-parallelism, 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism.
  95. if self.injection_dict:
  96. # 1. User specified Tensor Parallelism
  97. assert not config.replace_with_kernel_inject, "Cannot use both user specified injection policy and kernel injection"
  98. for client_module, injection_policy in self.injection_dict.items():
  99. # construct the tuple and pass that instead of a string or dict.
  100. if isinstance(injection_policy, str):
  101. config.injection_policy_tuple = (injection_policy, )
  102. else:
  103. config.injection_policy_tuple = injection_policy
  104. self._apply_injection_policy(config, client_module)
  105. else:
  106. if config.replace_with_kernel_inject:
  107. # 2. DeepSpeed Kernel Injection
  108. self._apply_injection_policy(config)
  109. else:
  110. # 3. Automatic Tensor Parallelism
  111. parser_dict = AutoTP.tp_parser(model)
  112. print("AutoTP: ", parser_dict)
  113. for client_module, injection_policy in parser_dict:
  114. if isinstance(injection_policy, str):
  115. config.injection_policy_tuple = (injection_policy, )
  116. else:
  117. config.injection_policy_tuple = injection_policy
  118. self._apply_injection_policy(config, client_module)
  119. device = get_accelerator().current_device_name()
  120. self.module.to(device)
  121. if config.tensor_parallel.tp_size > 1:
  122. _rng_state = get_accelerator().get_rng_state().to(
  123. get_accelerator().current_device_name())
  124. dist.broadcast(_rng_state, 0)
  125. get_accelerator().set_rng_state(_rng_state.cpu())
  126. if config.tensor_parallel.tp_size > 1:
  127. assert not config.enable_cuda_graph, "Cuda graph is not supported for model parallelism"
  128. # Check if local CUDA graphs can be created in replacement modules
  129. self.local_cuda_graph = self._local_cuda_graph_used(self.module)
  130. def profile_model_time(self, use_cuda_events=True):
  131. if not self.model_profile_enabled and not self._config.enable_cuda_graph:
  132. self.module.register_forward_pre_hook(self._pre_forward_hook)
  133. self.module.register_forward_hook(self._post_forward_hook)
  134. self.model_profile_enabled = True
  135. self.use_cuda_events = use_cuda_events
  136. if self.use_cuda_events:
  137. self.timers = SynchronizedWallClockTimer()
  138. # todo: remove this once all the config dicts are centralized from top level pydantic config
  139. def _get_model_config_generate(self, config):
  140. # this is being passed to replace_transformer_layer(config=self.user_model_config_dict)
  141. self.config = getattr(self.module,
  142. 'config',
  143. None) if config.config is None else config.config
  144. def remove_mask_prepare_for_bloom(self):
  145. if hasattr(self.module, 'transformer'):
  146. if hasattr(self.module.transformer, '_prepare_attn_mask'):
  147. self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
  148. def _pre_forward_hook(self, module, *inputs, **kwargs):
  149. if self.use_cuda_events:
  150. self.timers(INFERENCE_MODEL_TIMER).start()
  151. else:
  152. get_accelerator().synchronize()
  153. self._start = time.time()
  154. def _post_forward_hook(self, module, input, output):
  155. if self.use_cuda_events:
  156. self.timers(INFERENCE_MODEL_TIMER).stop()
  157. elapsed_time = self.timers(INFERENCE_MODEL_TIMER).elapsed(reset=True)
  158. else:
  159. get_accelerator().synchronize()
  160. self._end = time.time()
  161. elapsed_time = self._end - self._start
  162. self._model_times.append(elapsed_time)
  163. def _create_model_parallel_group(self, config):
  164. # Call the init process
  165. if InferenceEngine.inference_mp_group is None:
  166. init_distributed()
  167. local_rank = int(os.getenv('LOCAL_RANK', '0'))
  168. get_accelerator().set_device(local_rank)
  169. ranks = [i for i in range(config.tensor_parallel.tp_size)]
  170. self.mp_group = dist.new_group(ranks)
  171. InferenceEngine.inference_mp_group = self.mp_group
  172. else:
  173. self.mp_group = InferenceEngine.inference_mp_group
  174. def _create_ep_parallel_group(self, moe_experts):
  175. # Call the init process
  176. self.ep_group = {}
  177. self.expert_mp_group = {}
  178. moe_experts = moe_experts if type(moe_experts) is list else [moe_experts]
  179. for e in moe_experts:
  180. self.ep_group.update({e: None})
  181. self.expert_mp_group.update({e: None})
  182. for moe_ep_size in self.ep_group.keys():
  183. num_ep_groups = dist.get_world_size() // moe_ep_size
  184. for i in range(num_ep_groups):
  185. ep_cnt = i * moe_ep_size
  186. size = dist.get_world_size(
  187. ) if moe_ep_size > dist.get_world_size() else moe_ep_size
  188. ranks = list(range(ep_cnt, ep_cnt + size))
  189. _ep_group = dist.new_group(ranks)
  190. if dist.get_rank() in ranks:
  191. self.ep_group.update({moe_ep_size: _ep_group})
  192. if dist.get_world_size() > moe_ep_size:
  193. num_expert_mp_groups = dist.get_world_size() // num_ep_groups
  194. expert_mp_size = dist.get_world_size() // moe_ep_size
  195. for i in range(num_expert_mp_groups):
  196. expert_mp_comm_ranks = [
  197. i + nr * moe_ep_size for nr in range(expert_mp_size)
  198. ]
  199. _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
  200. if dist.get_rank() in expert_mp_comm_ranks:
  201. self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
  202. def _init_quantization_setting(self, quantization_setting):
  203. self.quantize_bits = 8
  204. self.mlp_extra_grouping = False
  205. self.quantize_groups = 1
  206. if type(quantization_setting) is tuple:
  207. self.mlp_extra_grouping, \
  208. self.quantize_groups = quantization_setting
  209. elif quantization_setting is not None:
  210. self.quantize_groups = quantization_setting
  211. log_dist(
  212. f"quantize_bits = {self.quantize_bits} "
  213. f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
  214. f"quantize_groups = {self.quantize_groups}",
  215. [0])
  216. # TODO: remove this function and add this functionality to pydantic config checking
  217. def _validate_args(self, mpu, replace_with_kernel_inject):
  218. # TODO: to support SD pipeline we need to avoid this check for now
  219. if replace_with_kernel_inject and not isinstance(self.module, Module):
  220. raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
  221. if not isinstance(self._config.tensor_parallel.tp_size,
  222. int) or self._config.tensor_parallel.tp_size < 1:
  223. raise ValueError(
  224. f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}"
  225. )
  226. if mpu:
  227. methods = ["get_model_parallel_group", "get_data_parallel_group"]
  228. for method in methods:
  229. if not hasattr(mpu, method):
  230. raise ValueError(f"mpu is missing {method}")
  231. if self._config.checkpoint is not None and not isinstance(
  232. self._config.checkpoint,
  233. (str,
  234. dict)):
  235. raise ValueError(
  236. f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}"
  237. )
  238. supported_dtypes = [None, torch.half, torch.int8, torch.float]
  239. if self._config.dtype not in supported_dtypes:
  240. raise ValueError(
  241. f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")
  242. if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
  243. raise ValueError(
  244. f"injection_dict must be None or a dict, got: {self.injection_dict}")
  245. def load_model_with_checkpoint(self, r_module):
  246. self.mp_replace = ReplaceWithTensorSlicing(
  247. mp_group=self.mp_group,
  248. mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
  249. error_msgs = []
  250. def load(module, state_dict, prefix):
  251. args = (state_dict, prefix, {}, True, [], [], error_msgs)
  252. if hasattr(module, 'weight'):
  253. if 'query_key_value' in prefix:
  254. module.weight = self.mp_replace.qkv_copy(
  255. module.weight.data,
  256. state_dict[prefix + 'weight'])
  257. else:
  258. module.weight = self.mp_replace.copy(module.weight.data,
  259. state_dict[prefix + 'weight'])
  260. else:
  261. module.norm.weight = self.mp_replace.copy(module.norm.weight.data,
  262. state_dict[prefix + 'weight'])
  263. if prefix + 'bias' in self.key_list:
  264. if hasattr(module, 'norm'):
  265. module.norm.bias = self.mp_replace.copy(module.norm.bias,
  266. state_dict[prefix + 'bias'])
  267. else:
  268. data = state_dict[prefix + 'bias']
  269. data = data.to(get_accelerator().current_device_name())
  270. module.bias = self.mp_replace.copy(module.bias, data)
  271. layer_policies = {
  272. nn.Linear: load,
  273. nn.Embedding: load,
  274. nn.LayerNorm: load,
  275. LinearLayer: load,
  276. LinearAllreduce: load
  277. }
  278. def load_module_recursive(module, prefix='', level=0):
  279. for name, child in module.named_children():
  280. if child.__class__ in layer_policies:
  281. checking_key = prefix + name + '.'
  282. if not any(checking_key in item for item in self.key_list):
  283. continue
  284. if len(list(child.parameters())) > 0 and list(
  285. child.parameters())[0].numel() == 0:
  286. if len(child.weight.ds_shape) == 1:
  287. child = Normalize(dim=child.weight.ds_shape[-1],
  288. dtype=child.weight.dtype,
  289. eps=child.eps)
  290. setattr(module, name, child)
  291. load(child, self.sd, prefix + name + '.')
  292. else:
  293. load_module_recursive(child,
  294. prefix if level == 0 else prefix + name + '.',
  295. level + 1)
  296. load_module_recursive(r_module)
  297. def _apply_injection_policy(self, config, client_module=None):
  298. # client_module is only passed when using the injection_dict method.
  299. checkpoint_dir = config.checkpoint
  300. checkpoint = SDLoaderFactory.get_sd_loader_json(
  301. checkpoint_dir,
  302. self.checkpoint_engine) if checkpoint_dir is not None else None
  303. generic_injection(self.module,
  304. fp16=(config.dtype == torch.half)
  305. or (config.dtype == torch.int8),
  306. enable_cuda_graph=config.enable_cuda_graph)
  307. if isinstance(self.module, torch.nn.Module):
  308. # config is our DeepSpeedInferenceConfig and self.config is the HF model config
  309. replace_transformer_layer(client_module,
  310. self.module,
  311. checkpoint,
  312. config,
  313. self.config)
  314. def _get_all_ckpt_names(self, checkpoints_path, tag):
  315. ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
  316. tag,
  317. mp_placeholder="*")
  318. import glob
  319. ckpt_files = glob.glob(ckpt_file_pattern)
  320. ckpt_files.sort()
  321. return ckpt_files
  322. def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
  323. if mp_placeholder is not None:
  324. mp_rank_str = mp_placeholder
  325. else:
  326. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  327. mp_rank_str = "{:02d}".format(mp_rank)
  328. ckpt_name = os.path.join(
  329. checkpoints_path,
  330. "mp_rank_" + mp_rank_str + "_model_states.pt",
  331. )
  332. return ckpt_name
  333. def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
  334. is_pipe_parallel = isinstance(self.module, PipelineModule)
  335. if is_pipe_parallel:
  336. raise RuntimeError(
  337. 'pipeline parallelism is currently not supported in inference.')
  338. if not isinstance(load_dir, dict) and os.path.isdir(load_dir):
  339. if tag is None:
  340. latest_path = os.path.join(load_dir, "latest")
  341. if os.path.isfile(latest_path):
  342. with open(latest_path, "r") as fd:
  343. tag = fd.read().strip()
  344. ckpt_list = self._get_all_ckpt_names(load_dir, tag)
  345. sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
  346. else:
  347. sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir,
  348. self.checkpoint_engine)
  349. if type(sd_loader) is list:
  350. self.sd = torch.load(sd_loader[0], map_location='cpu')
  351. self.key_list = list(self.sd.keys())
  352. self.load_model_with_checkpoint(self.module)
  353. for i in range(1, len(sd_loader)):
  354. if not dist.is_initialized() or dist.get_rank() == 0:
  355. print(f"loading checkpoint ({i})")
  356. self.sd = torch.load(sd_loader[i],
  357. map_location=get_accelerator().device_name())
  358. self.key_list = list(self.sd.keys())
  359. self.load_model_with_checkpoint(self.module)
  360. else:
  361. mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
  362. load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size,
  363. mp_rank,
  364. is_pipe_parallel=is_pipe_parallel,
  365. quantize=(self._config.dtype is torch.int8),
  366. quantize_groups=self.quantize_groups,
  367. mlp_extra_grouping=self.mlp_extra_grouping)
  368. self.quantization_scales, self.quantize_merge_count = quantize_config
  369. moe, _ = has_moe_layers(self.module)
  370. if moe:
  371. from deepspeed.runtime.engine import DeepSpeedEngine
  372. old_moe_load = False
  373. if not isinstance(checkpoint['num_experts'], list):
  374. old_moe_load = True
  375. DeepSpeedEngine.load_moe_state_dict(
  376. load_dir,
  377. tag,
  378. state_dict=checkpoint[self._choose_module_key(checkpoint)],
  379. old_moe_load=old_moe_load,
  380. model=self.module,
  381. mpu=self.mpu,
  382. checkpoint_engine=self.checkpoint_engine)
  383. self.module.load_state_dict(
  384. state_dict=checkpoint[self._choose_module_key(checkpoint)],
  385. strict=load_module_strict)
  386. def _choose_module_key(self, sd):
  387. assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
  388. assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' or 'module' keys, not sure how to proceed"
  389. if 'module' in sd:
  390. return 'module'
  391. elif 'model' in sd:
  392. return 'model'
  393. def _convert_to_dtype(self, config):
  394. if not isinstance(self.module, torch.nn.Module):
  395. return
  396. if False: #config.dtype is torch.int8 and self.quantization_scales is None:
  397. quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
  398. model, self.quantization_scales = quantizer.model_quantize(self.module,
  399. self.injection_dict,
  400. self.quantize_bits,
  401. self.quantize_groups)
  402. elif config.dtype == torch.half:
  403. self.module.half()
  404. elif config.dtype == torch.bfloat16:
  405. self.module.bfloat16()
  406. elif config.dtype == torch.float:
  407. self.module.float()
  408. def _create_cuda_graph(self, *inputs, **kwargs):
  409. # warmup to create the workspace and cublas handle
  410. cuda_stream = get_accelerator().Stream()
  411. cuda_stream.wait_stream(get_accelerator().current_stream())
  412. with get_accelerator().stream(cuda_stream):
  413. for i in range(3):
  414. ret = self.module(*inputs, **kwargs)
  415. get_accelerator().current_stream().wait_stream(cuda_stream)
  416. # create cuda_graph and assign static_inputs and static_outputs
  417. self._cuda_graphs = torch.cuda.CUDAGraph()
  418. self.static_inputs = inputs
  419. self.static_kwargs = kwargs
  420. with torch.cuda.graph(self._cuda_graphs):
  421. self.static_output = self.module(*self.static_inputs, **self.static_kwargs)
  422. self.cuda_graph_created = True
  423. def _graph_replay(self, *inputs, **kwargs):
  424. for i in range(len(inputs)):
  425. if torch.is_tensor(inputs[i]):
  426. self.static_inputs[i].copy_(inputs[i])
  427. for k in kwargs:
  428. if torch.is_tensor(kwargs[k]):
  429. self.static_kwargs[k].copy_(kwargs[k])
  430. self._cuda_graphs.replay()
  431. return self.static_output
  432. def model_times(self):
  433. assert self.model_profile_enabled, "model profiling is not enabled"
  434. model_times = self._model_times
  435. if self._config.enable_cuda_graph and len(self._model_times) == 0:
  436. raise ValueError(
  437. "Model times are empty and cuda graph is enabled. If "
  438. "this is a GPT-style model this combo is not supported. If this is a "
  439. "BERT-style model this is a bug, please report it. "
  440. f"Model type is: {type(self.module)}")
  441. self._model_times = []
  442. return model_times
  443. def _module_match(self, module):
  444. for policy in generic_policies:
  445. policy = policy()
  446. if policy.match_replaced(module):
  447. return True
  448. return False
  449. def _local_cuda_graph_used(self, module):
  450. if isinstance(module, torch.nn.Module):
  451. return False
  452. else:
  453. sub_module_cuda_graph = False
  454. for name in module.__dict__.keys():
  455. sub_module = getattr(module, name)
  456. if self._module_match(sub_module) and hasattr(sub_module,
  457. "enable_cuda_graph"):
  458. sub_module_cuda_graph = True
  459. return sub_module_cuda_graph
  460. def forward(self, *inputs, **kwargs):
  461. """Execute forward propagation
  462. Arguments:
  463. *inputs: Variable length input list
  464. **kwargs: variable length keyword arguments
  465. """
  466. start = None
  467. if self.model_profile_enabled and get_accelerator().device_name(
  468. ) == 'cuda' and self._config.enable_cuda_graph:
  469. get_accelerator().synchronize()
  470. start = time.time()
  471. if get_accelerator().device_name(
  472. ) == 'cuda' and self._config.enable_cuda_graph and not self.local_cuda_graph:
  473. if self.cuda_graph_created:
  474. outputs = self._graph_replay(*inputs, **kwargs)
  475. else:
  476. self._create_cuda_graph(*inputs, **kwargs)
  477. outputs = self._graph_replay(*inputs, **kwargs)
  478. else:
  479. outputs = self.module(*inputs, **kwargs)
  480. if self.model_profile_enabled and self._config.enable_cuda_graph:
  481. get_accelerator().synchronize()
  482. duration = time.time() - start
  483. self._model_times.append(duration)
  484. return outputs
  485. def _generate(self, *inputs, **kwargs):
  486. # Reset KV-cache at the beginning of generate
  487. if hasattr(self.module, 'reset_cache'):
  488. self.module.reset_cache()
  489. num_beams = 1
  490. if "generation_config" in kwargs:
  491. gen_config = kwargs["generation_config"]
  492. num_beams = getattr(gen_config, "num_beams", 1)
  493. if "num_beams" in kwargs:
  494. num_beams = kwargs["num_beams"]
  495. if num_beams > 1:
  496. raise NotImplementedError(
  497. "DeepSpeed does not support `num_beams` > 1, if this is important to you please "
  498. "add your request to: https://github.com/microsoft/DeepSpeed/issues/2506"
  499. )
  500. return self.module.generate(*inputs, **kwargs)