# replace_module.py

'''Copyright The Microsoft DeepSpeed Team'''
import os
import torch
import tqdm
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.accelerator import get_accelerator
from .replace_policy import HFGPT2LayerPolicy
from .replace_policy import replace_policies, generic_policies
from deepspeed import comm as dist
from torch import nn

from .layers import LinearAllreduce, LinearLayer
from .load_checkpoint import load_model_with_checkpoint
import time

from .utils import policy_to_ds_container


class ReplaceWithTensorSlicing:
    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        if mp_group is not None:
            self.gpu_index = dist.get_rank(group=mp_group)
        else:
            self.gpu_index = 0
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mp_size = mp_size

    def merge_assert(self, dim1, dim2):
        assert dim1 > dim2, \
            'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\
            for merging your checkpoints before replacing the transformer layer with\
            inference-kernels'

    def qkv_copy(self, dst, src, int8=False):
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape

        outer_dim = 0 if int8 else -1
        inner_dim = -1 if int8 else 0

        src_split = torch.split(src.data, src.shape[outer_dim] // 3, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst

            if self.out_dim == 1:
                self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                qkv_size = dst_shape[self.out_dim] // 3
                qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
                weight_split = [
                    torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim)
                    for i in range(len(qkv_split[0]))
                ]
                dst = dst.reshape(-1).data.copy_(
                    weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                        weight_split[self.gpu_index].shape)
            else:
                dst.data.copy_(src_split[self.gpu_index].to(
                    get_accelerator().current_device_name()).contiguous())
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)

            if self.out_dim == 1:
                qkv_size = dst_shape[0] // 3
                qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
                bias_split = [
                    torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0)
                    for i in range(len(qkv_split[0]))
                ]
                dst.data.copy_(bias_split[self.gpu_index].contiguous())
            else:
                dst.data.copy_(src_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst

    def copy(self, dst, src, int8=False):
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[inner_dim] == dst_shape[self.in_dim] and \
                    src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            else:
                if src_shape[inner_dim] != dst_shape[self.in_dim]:
                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                    weight_split = torch.split(
                        src, dst_shape[self.in_dim], dim=inner_dim)[self.gpu_index].contiguous()
                else:
                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                    weight_split = torch.split(
                        src.data, dst_shape[self.out_dim], dim=outer_dim)[self.gpu_index].contiguous()
                dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(
                    weight_split.shape)
        else:
            if src_shape[0] == dst_shape[0]:
                dst.data.copy_(src)
            else:
                bias_split = torch.split(src.data, dst_shape[-1])[self.gpu_index].contiguous()
                dst.data.copy_(bias_split)

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
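
# Hedged usage sketch (illustrative shapes, not executed here): ReplaceWithTensorSlicing.copy
# shards a full weight matrix across tensor-parallel ranks by slicing along out_dim/in_dim
# and keeping only this rank's slice.
#
#   mp_replace = ReplaceWithTensorSlicing(mp_group=None, mp_size=2)  # rank 0 of a 2-way TP group
#   full_weight = torch.randn(1024, 4096)                            # un-sharded source weight
#   local_weight = torch.empty(1024, 2048)                           # this rank's destination slice
#   local_weight = mp_replace.copy(local_weight, full_weight)        # keeps columns [0, 2048)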


def get_transformer_name(replaced_module):
    from .containers import supported_models
    from torch.nn import ModuleList
    transformer_name = ''
    for n, c in replaced_module.named_children():
        if c.__class__ in supported_models:
            transformer_name += n + '.'
            for name, child in c.named_children():
                if child.__class__ is ModuleList:
                    transformer_name += name
                    break
            break
    return transformer_name
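
# Note (illustrative, depends on the model): for a supported HF-style model the returned
# string is the dotted prefix of the transformer layer ModuleList, e.g. something like
# 'transformer.h' for GPT-2-like models.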


class GroupQuantizer:
    def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
        self.group_size = group_size
        self.num_bits = num_bits
        self.q_int8 = q_int8
        self.num_groups = num_groups

    def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
        if not self.q_int8 or not qkv:
            inputs = torch.nn.Parameter(inputs, requires_grad=False)
            inputs.scale = torch.empty(1)
            return inputs
        q_range = 2**self.num_bits
        num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
        inputs = inputs.to(get_accelerator().current_device_name())
        input_flat = inputs.reshape(num_groups, -1).contiguous()
        input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
        input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
        scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
        input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
        inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
        out = torch.nn.Parameter(inputs_q, requires_grad=False)
        inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
        input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
        input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        scale1 = [
            (torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
            for i in range(2)
        ]
        out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]],
                              dim=0).reshape(num_groups, -1).contiguous()
        return out
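
# Hedged usage sketch (illustrative shapes, not executed here): GroupQuantizer produces an
# int8 Parameter with per-group scales attached as a `scale` attribute.
#
#   quantizer = GroupQuantizer(q_int8=True, num_groups=32)
#   w = torch.randn(4096, 4096, dtype=torch.float16)
#   w_q = quantizer.quantize(w)   # int8 weight; w_q.scale holds the per-group scales
#   assert w_q.dtype == torch.int8 and hasattr(w_q, 'scale')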


def _module_match(module):
    for policy in generic_policies:
        policy = policy()
        if policy.match(module):
            return policy
    return None


def generic_injection(module, fp16=False, enable_cuda_graph=True):

    def replace_attn(child, policy):
        policy_attn = policy.attention(child)
        if policy_attn is None:
            return child
        if len(policy_attn) == 5:
            qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
        else:
            qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn

        config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=heads,
            fp16=fp16,
            triangular_masking=False,
            max_out_tokens=4096,
        )
        attn_module = DeepSpeedDiffusersAttention(config)

        def transpose(data):
            data = data.contiguous()
            data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            data.to(get_accelerator().current_device_name())
            return data

        if len(policy_attn) == 5:
            attn_module.attn_qkvw.data = transpose(qkvw.data)
        else:
            attn_module.attn_qkvw = None
            attn_module.attn_qw.data = transpose(qw.data)
            attn_module.attn_kw.data = transpose(kw.data)
            attn_module.attn_vw.data = transpose(vw.data)

        attn_module.attn_qkvb = None
        attn_module.attn_ow.data = transpose(attn_ow.data)
        attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
        return attn_module

    def replace_attn_block(child, policy):
        config = Diffusers2DTransformerConfig()
        return DeepSpeedDiffusersTransformerBlock(child, config)

    if isinstance(module, torch.nn.Module):
        pass
    else:
        if fp16 is False:
            raise ValueError("Generic injection only supported with FP16")

        try:
            import diffusers
            cross_attention = diffusers.models.attention.CrossAttention
            attention_block = diffusers.models.attention.BasicTransformerBlock
            new_policies = {
                cross_attention: replace_attn,
                attention_block: replace_attn_block,
            }
        except ImportError:
            new_policies = {}

        #replace_transformer_layer(None,
        #                          module.text_encoder,
        #                          training=False,
        #                          replace_with_kernel_inject=True,
        #                          triangular_masking=True,
        #                          max_out_tokens=8192)
        from ..model_implementations.transformers.clip_encoder import DSClipEncoder
        cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
        setattr(module, 'text_encoder', cg_encoder)
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
            policy = _module_match(sub_module)

            if policy is not None:

                def _replace_module(module, policy):
                    for name, child in module.named_children():
                        _replace_module(child, policy)
                        if child.__class__ in new_policies:
                            replaced_module = new_policies[child.__class__](child, policy)
                            setattr(module, name, replaced_module)

                _replace_module(sub_module, policy)
                new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
                print(f"**** found and replaced {name} w. {type(new_module)}")
                setattr(module, name, new_module)
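
# Hedged usage sketch (assumes a diffusers-style pipeline object, not executed here):
# generic_injection targets pipelines such as Stable Diffusion, wrapping the text encoder
# in DSClipEncoder and replacing matched attention blocks via the generic policies.
#
#   pipe = diffusers.StableDiffusionPipeline.from_pretrained(...)  # hypothetical setup
#   generic_injection(pipe, fp16=True, enable_cuda_graph=False)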


container_g = None


def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine
        config: top-level DS Inference config defined in inference/config.py
        model_config: HuggingFace model config passed from the inference/engine.py
    Returns:
        Updated nn.module with replaced transformer layers
    """
    # defining globals as internally defined functions inherit these everywhere
    fp16 = (config.dtype == torch.float16 or config.dtype == torch.int8)
    quantize = (config.dtype == torch.int8)
    # todo: Refactor later. In future, let's minimize the style used above and use config.** instead

    linear_layer_setting = None
    '''
    linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers and embedding layers
    '''
    micro_batch_size = -1
    seed = -1
    local_rank = -1

    mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
                                          mp_size=config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
        policy = policy_cls(child, inference=inference)
        if not policy.cuda_graph_supported:
            # policy says cuda graph is not supported raise an error if set
            assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"

        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        # 1. Create a model-specific container object using the policy object.
        _container = policy_to_ds_container(policy=policy,
                                            config=config,
                                            model_config=model_config,
                                            layer_id=layer_id,
                                            child=child)
        _container.set_dtype(fp16)
        _container.set_moe(moe)

        # 2. Set the tensor parallelism config
        _container.set_tensor_parallel_config(config.tensor_parallel.tp_size,
                                              config.tensor_parallel.tp_group)

        # 3. Initialize tensors
        _container.initialize_tensors()

        # 4. deal with data types -- needs refactor to use dtype instead of fp16
        if fp16:
            _container.convert_to_required_dtype(dtype=torch.half)

        # 5. Set the quantization config
        quantizer = GroupQuantizer(q_int8=quantize)
        _container.set_quantization_config(quantize, quantizer)

        # 6. create a DS Inference config object
        _container.create_ds_model_config()

        # 7. use the config and create the module
        _container.create_module()

        # 8. transpose the weights and bias if needed
        _container.transpose()

        # 9. deal with tensor parallelism.
        _container.apply_tensor_parallelism(mp_replace)

        # 10. copy the tensors from the model-specific container to the new module
        _container.copy_data_to_new_module()

        # 11. set global for generic checkpoint loading
        global container_g

        if container_g is None:
            container_g = _container

        return _container.module

    def replace_wo_policy(module, all_reduce_linears):
        mp_size = config.tensor_parallel.tp_size
        mp_group = config.tensor_parallel.tp_group

        def _replace(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            weight_shape = child.weight.shape
            if name in all_reduce_linears:
                new_weight = torch.empty(
                    (weight_shape[1] if conv_linear_layer else weight_shape[0],
                     (weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size),
                    device=child.weight.device,
                    dtype=child.weight.dtype)
                if conv_linear_layer:
                    child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
                data = mp_replace.copy(new_weight, child.weight.data)
                new_bias = torch.empty((weight_shape[0]),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                if child.bias is not None:
                    new_bias.data.copy_(child.bias.data)
                return LinearAllreduce(
                    data,
                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
                        new_bias.to(get_accelerator().current_device_name())),
                    mp_group)
            else:
                new_weight = torch.empty(
                    ((weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size,
                     weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1]),
                    device=child.weight.device,
                    dtype=child.weight.dtype)
                if conv_linear_layer:
                    child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
                data = mp_replace.copy(new_weight, child.weight.data)
                new_bias = torch.empty((weight_shape[0] // mp_size),
                                       device=child.weight.device,
                                       dtype=child.weight.dtype)
                bias_data = None if child.bias is None else mp_replace.copy(
                    new_bias, child.bias.data).to(get_accelerator().current_device_name())
                return LinearLayer(weight=data.to(get_accelerator().current_device_name()),
                                   bias=bias_data)

        def _slice_embedding(child, name, conv_linear_layer):
            mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
            new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            data = mp_replace.copy(
                new_weight,
                child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else child.weight.data)
            new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size)
            new_embedding.weight.data.copy_(data)
            return new_embedding

        def update_mp_params(child):
            if hasattr(child, 'n_heads'):
                child.n_heads = child.n_heads // mp_size
            if hasattr(child, 'inner_dim'):
                child.inner_dim = child.inner_dim // mp_size
            if hasattr(child, 'num_heads'):
                child.num_heads = child.num_heads // mp_size
            if hasattr(child, 'num_attention_heads'):
                child.num_attention_heads = child.num_attention_heads // mp_size
            if hasattr(child, 'num_attn_heads'):
                child.num_attn_heads = child.num_attn_heads // mp_size
            if hasattr(child, 'all_head_size'):
                child.all_head_size = child.all_head_size // mp_size
            if hasattr(child, 'embed_dim'):
                child.embed_dim = child.embed_dim // mp_size
            if hasattr(child, 'hidden_size'):
                child.hidden_size = child.hidden_size // mp_size

        conv_linear_layer = False
        if linear_layer_setting is not None:
            linear_policies = {linear_layer_setting[0]: _replace}
            if len(linear_layer_setting) == 2:
                linear_policies.update({linear_layer_setting[1]: _slice_embedding})
        else:
            if orig_layer_impl is HFGPT2LayerPolicy._orig_layer_class:
                try:
                    import transformers
                    conv_linear_layer = True
                    linear_policies = {transformers.model_utils.Conv1D: _replace}
                except ImportError:
                    linear_policies = {nn.Linear: _replace}
            else:
                linear_policies = {nn.Linear: _replace, nn.Embedding: _slice_embedding}

        def _replace_module(r_module, prev_name=''):
            for name, child in r_module.named_children():
                if child.__class__ in linear_policies:
                    setattr(r_module, name,
                            linear_policies[child.__class__](child,
                                                             prev_name + '.' + name,
                                                             conv_linear_layer))
                else:
                    update_mp_params(child)
                    _replace_module(child, name)
            return r_module

        return _replace_module(module)
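
    # Note (illustrative): replace_wo_policy shards a layer without fused kernels:
    # nn.Linear modules whose names appear in `all_reduce_linears` become LinearAllreduce
    # (row-parallel, outputs summed over mp_group), all other nn.Linear modules become
    # LinearLayer (column-parallel), nn.Embedding weights are sliced along the embedding
    # dimension, and head/hidden-size attributes are divided by mp_size.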

    def replace_fn(child, _policy, layer_id=0):
        training = False  # todo: refactor this part to go in the config
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, config.triangular_masking)
        else:
            # copy relevant state from child -> new module
            if config.replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 config.triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy)
        return new_module

    replaced_module = replace_module(model=model,
                                     orig_class=orig_layer_impl,
                                     replace_fn=replace_fn,
                                     _replace_policy=config.injection_policy_tuple)

    quantizer = GroupQuantizer(q_int8=quantize)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None:
        assert container_g.ckpt_load_enabled, \
            f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
        ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
        base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)

        if ckpt_type == 'pp' and type(checkpoint) is list:
            pbar = tqdm.tqdm(total=len(checkpoint),
                             desc=f"Loading {len(checkpoint)} checkpoint shards")
            for i in range(len(checkpoint)):
                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           container=container_g)
                pbar.update(1)
        else:
            import gc
            num_checkpoints = len(ckpt_list) // ckpt_mp_size
            tp_split_size = (world_size / ckpt_mp_size)
            sd_offset = int(rank / tp_split_size)
            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
            pbar = tqdm.tqdm(total=num_checkpoints,
                             desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                pbar.update(1)
                ckpt_index = i * ckpt_mp_size + sd_offset
                ckpt_files = [
                    os.path.join(base_dir1, ckpt_list[ckpt_index + j])
                    if base_dir1 else ckpt_list[ckpt_index + j] for j in range(sd_count)
                ]
                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
                load_model_with_checkpoint(replaced_module,
                                           sds,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           int(rank % tp_split_size),
                                           container=container_g)
                sds = [None for _ in sds]
                gc.collect()

            if "non_tp" in checkpoint:
                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")
                for i in range(len(checkpoint["non_tp"])):
                    pbar.update(1)
                    ckpt_file = os.path.join(base_dir1, checkpoint["non_tp"][i]) \
                        if base_dir1 else checkpoint["non_tp"][i]
                    sds = [torch.load(ckpt_file, map_location='cpu')]
                    load_model_with_checkpoint(replaced_module,
                                               sds,
                                               mp_replace,
                                               ckpt_type,
                                               ckpt_mp_size,
                                               quantizer,
                                               int(rank % tp_split_size),
                                               container=container_g)
                    sds = [None for _ in sds]
                    gc.collect()
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if config.save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json
        num_partitions = 8

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = f'non-tp.pt'
        ckpt_files = [non_tp_ckpt_name]
        os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)
        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({
                    k: v
                    for k, v in dict(replaced_module.state_dict()).items()
                    if transformer_name not in k
                }),
                f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
            ckpt_config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{config.save_mp_checkpoint_path}',
                'checkpoints': {
                    "non_tp": ckpt_files,
                    "tp": [
                        f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions)
                        for r in range(world_size)
                    ]
                },
                'version': 1.0,
                'parallelization': 'tp',
                'tp_size': world_size,
                'dtype': 'int8' if quantize else ('float16' if fp16 else 'float32')
            })
            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
                cfg.write(ckpt_config)

        rep_sd = replaced_module.state_dict()
        for n, p in replaced_module.named_parameters():
            if hasattr(p, 'scale'):
                rep_sd[n] = [p, p.scale]
        keys = list(rep_sd.keys())
        partition_size = (len(keys) // num_partitions + 1)
        for m in range(num_partitions):
            torch.save(
                OrderedDict({
                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
                    for k in keys[m * partition_size:(m + 1) * partition_size]
                    if transformer_name in k
                }),
                f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

    return replaced_module
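
# Hedged usage sketch (not executed here): replace_transformer_layer is normally driven by
# deepspeed.init_inference rather than called directly. A typical indirect invocation,
# with a hypothetical model object, looks roughly like:
#
#   model = AutoModelForCausalLM.from_pretrained("some-model")   # hypothetical setup
#   engine = deepspeed.init_inference(model,
#                                     dtype=torch.float16,
#                                     mp_size=2,
#                                     replace_with_kernel_inject=True)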


def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        config (dict): model config containing hidden size, attention heads, etc.
    Returns:
        Updated nn.module with original bert-style transformer layers
    """
    def replace_fn(child, _replace_policy, layer_id):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        qw, kw, vw = torch.chunk(qkvw, 3, axis=0)
        qb, kb, vb = torch.chunk(qkvb, 3, axis=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b

        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn,
                          _replace_policy=None)


def replace_module(model, orig_class, replace_fn, _replace_policy):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
                             desired type and return a new instance.
    Returns:
        A modified ``model``.
    """
    policy = {}
    if orig_class is not None:
        policy.update({orig_class: (replace_fn, _replace_policy)})
    else:
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.update({orig_layer_class: (replace_fn, plcy)})
            elif plcy._orig_layer_class is not None:
                policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
    assert len(policy.items()) > 0,\
        "No default policy found! Please specify your policy injection_policy (like {BertLayer: HFBertLayerPolicy}). " +\
        "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

    replaced_module, _ = _replace_module(model, policy)
    return replaced_module
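
# Hedged sketch (mirrors the assert message above, not executed here): when no default
# policy matches, a user-supplied injection policy maps an original layer class to a
# DeepSpeed policy class, e.g.
#
#   from transformers.models.bert.modeling_bert import BertLayer
#   from deepspeed.module_inject import HFBertLayerPolicy   # assumed import path
#   deepspeed.init_inference(model, injection_policy={BertLayer: HFBertLayerPolicy})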


from ..pipe import PipelineModule


def _replace_module(model, policies, layer_id=0):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.
    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            replaced_module = policies[child.__class__][0](child,
                                                           policies[child.__class__][-1],
                                                           layer_id)
            setattr(model, name, replaced_module)
            if isinstance(model, PipelineModule):
                assert hasattr(model, 'forward_funcs'),\
                    "we require pipe-module to have the list of fwd_functions"
                model.forward_funcs[model.fwd_map[name]] = replaced_module
            layer_id += 1
        else:
            _, layer_id = _replace_module(child, policies, layer_id=layer_id)

    # Add the reset_cache func to the model, so that it can be called in the beginning of text-generation.
    model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
    return model, layer_id
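
# Note (illustrative): _replace_module walks named_children() depth-first; `layer_id`
# counts replaced layers in traversal order and is forwarded to the replace function so
# kernel-injected layers know their position in the model.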