# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.inference.config import DeepSpeedInferenceConfig
from deepspeed.module_inject.replace_policy import replace_policies
from deepspeed.module_inject.utils import policy_to_ds_container
from .engine import DeepSpeedEngine
from .utils import TLinear, get_inactive_params
from deepspeed.runtime.zero import GatheredParameters
import time
import gc
import math
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator
from torch import nn
from deepspeed.utils import logger
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.module_inject.layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding

try:
    import transformers
    OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except:
    OPTLearnedPositionalEmbedding = None

inference_cuda_module = None
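
# DeepSpeedHybridEngine drives workloads (e.g. RLHF) that alternate between
# generation and training on the same model: eval() swaps the wrapped modules'
# forwards over to fused inference containers, generate() gathers/fuses parameters
# as needed before calling the model's original generate, and train()/step()
# restore and resynchronize the training-side parameters.
#
# Rough usage sketch (illustrative only; the engine is normally constructed by
# deepspeed.initialize when the DeepSpeed config enables the hybrid engine,
# e.g. "hybrid_engine": {"enabled": true}):
#
#   engine, *_ = deepspeed.initialize(model=model, config=ds_config, ...)
#   engine.eval()                                 # switch to inference containers
#   outputs = engine.module.generate(input_ids)   # routed to DeepSpeedHybridEngine.generate
#   engine.train()                                # restore original forwards
#   loss = compute_loss(engine.module, batch)     # placeholder for the user's training step
#   engine.backward(loss)
#   engine.step()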
class DeepSpeedHybridEngine(DeepSpeedEngine):
    r"""DeepSpeed engine for training and inference."""
    inference_mp_group = None

    def __init__(self, args, model, **kwargs):
        super().__init__(args, model, **kwargs)

        # synch seed between all GPUs
        _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name())
        dist.broadcast(_rng_state, 0)
        get_accelerator().set_rng_state(_rng_state.cpu())

        self.Z3_enabled = (self._config.zero_config.stage == 3)
        self.gather_all_layers = self._config.hybrid_engine.pin_parameters

        # inference containers / fwds
        self._inference_containers = []
        self._orig_modules = []
        self._orig_fwds = []
        self.create_inference_module()

        # Performance stats
        self._t_start = None
        self._total_latency = 0
        self._iters = 0
        self._training_start_time = None
        self._generate_latency = 0
        self._training_latency = 0
        self._total_batch_size = None
        self._gather_latency = 0

        global inference_cuda_module
        if inference_cuda_module is None:
            builder = InferenceBuilder()
            inference_cuda_module = builder.load()

        self.is_lora_fused = False
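
    # Replace nn.Linear modules that live under an nn.ModuleList (i.e. inside the
    # stacked transformer blocks) with TLinear wrappers, which keep the weight in
    # transposed form to match the layout expected by the inference kernels.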
    def convert_to_linear_transposed(self, model):

        def _replace_linear_layer(r_module, parent_type=None, prev_type=None):
            for name, child in r_module.named_children():
                if child.__class__ in [torch.nn.Linear] and \
                    (parent_type is torch.nn.ModuleList or prev_type is torch.nn.ModuleList):
                    setattr(r_module, name, TLinear(child, name))
                else:
                    _replace_linear_layer(child, type(r_module), prev_type=parent_type)
            return r_module

        _replace_linear_layer(model)
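
    # Build a DeepSpeed inference container for one transformer layer: the matching
    # replace policy reads the layer's tensors, and the container is configured for
    # the engine's dtype, its tensor-parallel layout, and training-compatible
    # (empty-init, transposed-mode) parameters.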
    def new_inference_container(self, orig_layer, policy_cls, layer_id):
        policy = policy_cls(orig_layer, inference=True)

        if self._config.fp16_enabled:
            inference_dtype = torch.float16
        elif self._config.bfloat16_enabled:
            inference_dtype = torch.bfloat16
        else:
            inference_dtype = torch.float32

        _container = policy_to_ds_container(
            policy=policy,
            config=DeepSpeedInferenceConfig(
                set_empty_params=True,
                dtype=inference_dtype,
                max_out_tokens=self._config.hybrid_engine.max_out_tokens,
                min_out_tokens=self._config.hybrid_engine.max_out_tokens,
                transposed_mode=True,
            ),
            model_config=self.module.config if hasattr(self.module, 'config') else None,
            layer_id=layer_id,
            child=orig_layer)

        if self.mpu is not None:
            if hasattr(self.mpu, 'get_model_parallel_world_size'):
                _container.set_tensor_parallel_config(self.mpu.get_model_parallel_world_size(),
                                                      self.mpu.get_model_parallel_group())
            else:
                _container.set_tensor_parallel_config(self.mpu.get_tensor_model_parallel_world_size(),
                                                      self.mpu.get_tensor_model_parallel_group())
        else:
            _container.set_tensor_parallel_config(self._config.hybrid_engine.inference_tp_size, self.mp_group)
        _container.initialize_tensors(enable_training=True)
        _container.create_ds_model_config()
        _container.create_module()
        _container.set_params_wo_copy(Z3_enabled=self.Z3_enabled)
        return _container
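
    # Map original module classes to their inference-time replacements: transformer
    # layers go through new_inference_container, while plain Linear / Embedding /
    # LayerNorm (and OPT positional embeddings) map to the lightweight layers from
    # deepspeed.module_inject.layers.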
    def populate_all_inference_policies(self):
        self.inference_policies = {}
        for plcy in replace_policies:
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    self.inference_policies.update({orig_layer_class: (self.new_inference_container, plcy)})
            elif plcy._orig_layer_class is not None:
                self.inference_policies.update({plcy._orig_layer_class: (self.new_inference_container, plcy)})
        self.inference_policies.update({
            nn.Linear: (LinearLayer, ),
            nn.Embedding: (EmbeddingLayer, ),
            nn.LayerNorm: (Normalize, ),
            OPTLearnedPositionalEmbedding: (OPTEmbedding, )
        })
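
    # LoRA helpers: fusing folds the low-rank LoRA deltas into the base weights
    # before generation; unfusing restores the original weights afterwards. The
    # non-pinned variant gathers each layer's ZeRO-3 partitioned parameters before
    # unfusing.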
    def _fuse_lora_layer(self, layer_id):
        self._inference_containers[layer_id].fuse_lora()

    def fuse_lora_weight(self):
        for layer_id in range(len(self.layer_params)):
            self._fuse_lora_layer(layer_id)

    def _unfuse_lora_layer(self, layer_id):
        self._inference_containers[layer_id].unfuse_lora()

    def unfuse_lora_weight(self):
        for layer_id in range(len(self.layer_params)):
            self._unfuse_lora_layer(layer_id)

    def unfuse_lora_weight_non_pinned(self):
        for layer_id in range(len(self.layer_params)):
            non_active_params = get_inactive_params(self.layer_params[layer_id])
            non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
            non_active_params.extend(non_active_lora_params)
            with GatheredParameters(non_active_params):
                self._unfuse_lora_layer(layer_id)
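
    # Re-acquire the inference kernel workspace when release_inference_cache is
    # enabled, emptying the allocator cache and retrying once before giving up.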
    def retake_inference_cache(self):
        if self._config.hybrid_engine.release_inference_cache:
            retake_success = inference_cuda_module.retake_workspace()

            if not retake_success:
                logger.warning("Unable to acquire workspace on first attempt, emptying cache and retrying.")
                gc.collect()
                get_accelerator().empty_cache()
                retake_success = inference_cuda_module.retake_workspace()

                if not retake_success:
                    raise RuntimeError("Unable to retake inference workspace.")
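
    # Generation entry point installed as module.generate. Depending on ZeRO-3 and
    # parameter pinning it gathers partitioned layers (group by group when the
    # inference TP size > 1), fuses LoRA weights, runs the original generate, then
    # unfuses / releases memory and records the gather and generate latencies.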
    def generate(self, *inputs, **kwargs):
        if self._total_batch_size is None:
            bsz = inputs[0].shape[0] if len(inputs) > 0 else \
                kwargs['input_ids'].shape[0]
            self._total_batch_size = bsz * dist.get_world_size()

        self._t0 = time.time()

        if self.Z3_enabled and self.gather_all_layers:
            if self._config.hybrid_engine.inference_tp_size > 1:
                non_tp_params = []
                for other_layer in self._other_layers:
                    non_tp_params.extend(list(other_layer.parameters()))

                partition_size = self._config.hybrid_engine.tp_gather_partition_size

                layer_groups = math.ceil(len(self.layer_params) / partition_size)
                for lg in range(layer_groups):
                    non_active_params = []
                    non_active_lora_params = []
                    for layer_id in range(lg * partition_size, min(len(self.layer_params), (lg + 1) * partition_size),
                                          1):
                        non_tp_params.extend(self.layer_params[layer_id][:4])
                        non_active_params.extend(get_inactive_params(self.layer_params[layer_id]))
                        non_active_params.extend(get_inactive_params(self.layer_lora_params[layer_id]))
                    with GatheredParameters(non_active_params):
                        for layer_id in range(lg * partition_size,
                                              min(len(self.layer_params), (lg + 1) * partition_size), 1):
                            if len(self.all_lora_params) > 0:
                                self._fuse_lora_layer(layer_id)

                            if self.mpu is not None:
                                self._inference_containers[layer_id].apply_tensor_parallelism(self.mp_replace,
                                                                                              reversed_dim=True)

                # TODO(cmikeh2) Evaluate if this can be deferred when release_inference_cache
                # is enabled.
                gc.collect()
                get_accelerator().empty_cache()

                self._gather_latency = time.time() - self._t0

                input_shape = inputs[0].shape if len(inputs) > 0 else \
                    kwargs['input_ids'].shape
                output = torch.zeros(
                    (input_shape[0] * self._config.hybrid_engine.inference_tp_size, ) + input_shape[1:],
                    dtype=inputs[0].dtype if len(inputs) > 0 else kwargs['input_ids'].dtype,
                    device=inputs[0].device if len(inputs) > 0 else kwargs['input_ids'].device)
                input_cont = inputs[0].contiguous() if len(inputs) > 0 else kwargs['input_ids'].contiguous()
                dist.all_gather_into_tensor(output, input_cont, group=self.mp_group)

                if len(inputs) > 0:
                    inputs = (output, *inputs[1:])
                else:
                    kwargs['input_ids'] = output

                self.retake_inference_cache()

                non_active_params = get_inactive_params(non_tp_params)
                with GatheredParameters(non_active_params):
                    generate_ret_vals = self._generate(*inputs, **kwargs)

                for layer_id in range(len(self.layer_params)):
                    self._inference_containers[layer_id].release_memory()

                rank = dist.get_rank(group=self.mp_group)
                generate_ret_vals = generate_ret_vals[input_shape[0] * rank:input_shape[0] * (rank + 1)]

            else:
                non_active_layers = get_inactive_params(self.all_layers_params)
                non_active_lora_params = get_inactive_params(self.all_lora_params)
                non_active_layers.extend(non_active_lora_params)
                with GatheredParameters(non_active_layers):
                    self._gather_latency = time.time() - self._t0

                    if len(self.all_lora_params) > 0:
                        self.fuse_lora_weight()

                    self.retake_inference_cache()
                    generate_ret_vals = self._generate(*inputs, **kwargs)

                    if len(self.all_lora_params) > 0:
                        self.unfuse_lora_weight()
        else:
            if len(self.all_lora_params) > 0 and (not self.Z3_enabled):
                self.fuse_lora_weight()

            self.retake_inference_cache()
            generate_ret_vals = self._generate(*inputs, **kwargs)

            if len(self.all_lora_params) > 0:
                if (not self.Z3_enabled):
                    self.unfuse_lora_weight()
                else:
                    self.unfuse_lora_weight_non_pinned()
                self.is_lora_fused = False

        if self._config.hybrid_engine.release_inference_cache:
            inference_cuda_module.release_workspace()
            gc.collect()
            get_accelerator().empty_cache()

        self._generate_latency = time.time() - self._t0 - self._gather_latency

        return generate_ret_vals
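
    # Walk the module tree and, for every child covered by an inference policy,
    # either build a full inference container (transformer layers) or a lightweight
    # replacement layer (embeddings, norms, linears), remembering the original
    # modules and forwards so train() can restore them.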
    def create_inference_containers(self, module, layer_id=0):
        for name, child in module.named_children():
            if child.__class__ in self.inference_policies:
                if self.inference_policies[child.__class__][0] == self.new_inference_container:
                    self._inference_containers.append(self.inference_policies[child.__class__][0](
                        child, self.inference_policies[child.__class__][-1], layer_id))
                    self._orig_modules.append(child)
                    self._orig_fwds.append(child.forward)

                    self.layer_params.append(self._inference_containers[layer_id].get_all_params())

                    self.lora_params.append(self._inference_containers[layer_id].get_lora_params())
                    self.layer_lora_params.append([])
                    for lora_param in self.lora_params[layer_id]:
                        self.layer_lora_params[layer_id].extend(lora_param[:-1])
                        self.all_lora_params.extend(lora_param[:-1])

                    layer_id += 1
                else:
                    self._other_layers.append(self.inference_policies[child.__class__][0](
                        weight=child.weight, bias=child.bias if hasattr(child, 'bias') else None))
                    self._orig_modules_others.append(child)
                    self._orig_fwds_others.append(child.forward)
            else:
                self.create_inference_containers(child, layer_id=layer_id)
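
    # One-time setup: build the tensor-parallel groups used for inference (when
    # inference_tp_size > 1), register the inference policies, create containers for
    # all layers, and hook module.generate so generation is routed through this
    # engine.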
    def create_inference_module(self):
        self.layer_params = []
        self.layer_lora_params = []
        self.lora_params = []
        self.all_lora_params = []

        self._other_layers = []
        self._orig_modules_others = []
        self._orig_fwds_others = []

        if self._config.hybrid_engine.inference_tp_size > 1:
            if self.mpu is None:
                global_rank = dist.get_rank()
                world_size = dist.get_world_size()
                mp_group_id = global_rank // self._config.hybrid_engine.inference_tp_size
                num_mp_groups = world_size // self._config.hybrid_engine.inference_tp_size
                for mp_group_id in range(num_mp_groups):
                    ranks = list(
                        range(mp_group_id * self._config.hybrid_engine.inference_tp_size, \
                              (mp_group_id + 1) * self._config.hybrid_engine.inference_tp_size, \
                              1))
                    mp_group = dist.new_group(ranks)
                    if global_rank in ranks:
                        # mp_group is used for broader collective
                        self.mp_group = mp_group

                        # mp_replace is used for container tensor slicing
                        from deepspeed.module_inject import ReplaceWithTensorSlicing
                        self.mp_replace = ReplaceWithTensorSlicing(
                            mp_group=self.mp_group,
                            mp_size=self._config.hybrid_engine.inference_tp_size,
                            out_dim=0,
                            in_dim=1)
            else:
                self.mp_group = self.mpu.get_model_parallel_group() if hasattr(self.mpu, 'get_model_parallel_group') else \
                    self.mpu.get_tensor_model_parallel_group()

                from deepspeed.module_inject import ReplaceWithTensorSlicing
                self.mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group,
                                                           mp_size=self._config.hybrid_engine.inference_tp_size,
                                                           out_dim=0,
                                                           in_dim=1)
        else:
            self.mp_group = None
            self.mp_replace = None

        self.populate_all_inference_policies()

        self.all_layers_params = list(self.module.parameters())
        self.create_inference_containers(self.module)

        if len(self._inference_containers) > 0:
            self._generate = self.module.generate
            self.module.generate = self.generate

        self._t0 = time.time()
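
    # ZeRO-3 path without parameter pinning: wrap a layer's forward so that its
    # partitioned (and LoRA) parameters are gathered just-in-time and the call is
    # dispatched to the corresponding inference container.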
    def _zero3_forward(self, layer_id):

        def run_forward(*inputs, **kwargs):
            non_active_params = get_inactive_params(self.layer_params[layer_id])
            non_active_lora_params = get_inactive_params(self.layer_lora_params[layer_id])
            non_active_params.extend(non_active_lora_params)

            with GatheredParameters(non_active_params):
                if len(self.all_lora_params) > 0:
                    # Use the is_lora_fused flag to prevent multiple fusion in Z3 with non-pinned memory
                    if not self.is_lora_fused:
                        self._fuse_lora_layer(layer_id)
                    # Set the is_lora_fused to true when reaching the last layer
                    if layer_id == len(self.layer_params) - 1:
                        self.is_lora_fused = True
                return self._inference_containers[layer_id].module.forward(*inputs, **kwargs)

        return run_forward
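
    # Switching to evaluation prints the latency breakdown of the previous
    # iteration, then swaps each wrapped module's forward to its inference
    # container (or to the ZeRO-3 just-in-time wrapper) and transforms the
    # parameters into the inference layout.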
    def eval(self):
        if self._t_start is not None:
            latency = time.time() - self._t_start
            self._total_latency = self._total_latency + latency
            self._iters = self._iters + 1
            if not dist.is_initialized() or dist.get_rank() == 0:
                if self._total_batch_size is not None:
                    cur_samples_p_sec = f'|CurSamplesPerSec={(1 / latency * self._total_batch_size):.2f} '
                    avg_samples_p_sec = f'|AvgSamplesPerSec={(1 / (self._total_latency / self._iters) * self._total_batch_size):.2f}'
                else:
                    cur_samples_p_sec = ''
                    avg_samples_p_sec = ''
                others = latency - (self._generate_latency + self._training_latency)
                print(f'|E2E latency={(latency):.2f}s ' + \
                      f'|Gather latency={self._gather_latency:.2f}s ({(self._gather_latency / latency * 100):.2f}%) '
                      f'|Generate time={(self._generate_latency):.2f}s ({(self._generate_latency / latency * 100):.2f}%) ' + \
                      f'|Training time={(self._training_latency):.2f}s ({(self._training_latency / latency * 100):.2f}%) ' + \
                      f'|Others={others:.2f} ({(others / latency * 100):.2f}%)' + \
                      cur_samples_p_sec + \
                      avg_samples_p_sec)
            self._t_start = time.time()
        self._training_latency = 0
        super().eval()

        if len(self._inference_containers) > 0:
            for i, (orig_module, inference_container) in enumerate(zip(self._orig_modules,
                                                                       self._inference_containers)):
                if self.Z3_enabled and not self.gather_all_layers:
                    orig_module.forward = self._zero3_forward(i)
                else:
                    orig_module.forward = inference_container.module.forward
                inference_container.transform_for_inference()

            if not self.Z3_enabled or self.gather_all_layers:
                for orig_module, inference_layer in zip(self._orig_modules_others, self._other_layers):
                    orig_module.forward = inference_layer.forward

        if self.Z3_enabled:
            gc.collect()
            get_accelerator().empty_cache()

        if self._t_start is None:
            self._t_start = time.time()
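
    # Switching back to training restores the original forwards and parameter
    # layout for all wrapped modules.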
    def train(self, mode=True):
        if mode and len(self._orig_modules) > 0:
            for inference_container, orig_module, orig_fwd in zip(self._inference_containers, self._orig_modules,
                                                                  self._orig_fwds):
                inference_container.transform_for_training()
                orig_module.forward = orig_fwd

            for orig_module, orig_fwd in zip(self._orig_modules_others, self._orig_fwds_others):
                orig_module.forward = orig_fwd

        super().train(mode)
        if mode:
            self._training_start_time = time.time()
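
    # After the optimizer step, refresh the inference containers' parameter
    # references (non-ZeRO-3 case) and accumulate training latency.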
    def step(self, lr_kwargs=None):
        super().step(lr_kwargs=lr_kwargs)

        if len(self._inference_containers) > 0:
            if not self.Z3_enabled:
                for inference_container in self._inference_containers:
                    inference_container.reset_params()

        if self._training_start_time is not None:
            self._training_latency += (time.time() - self._training_start_time)
            self._training_start_time = time.time()