replace_module.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import tqdm
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from deepspeed import comm as dist
from .load_checkpoint import load_model_with_checkpoint
import time
from .utils import policy_to_ds_container
import gc


def get_transformer_name(replaced_module):
    from .containers import supported_models
    from torch.nn import ModuleList
    transformer_name = ''
    for n, c in replaced_module.named_children():
        if c.__class__ in supported_models:
            transformer_name += n + '.'
            for name, child in c.named_children():
                if child.__class__ is ModuleList:
                    transformer_name += name
                    break
            break
    return transformer_name
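
# Illustrative sketch (assumption, not part of the original file): for a typical HF decoder
# such as BloomForCausalLM, whose supported child module is named "transformer" and holds a
# ModuleList named "h", get_transformer_name() would return "transformer.h". The exact names
# depend on the model implementation and on what .containers registers as supported_models.
#
#   name = get_transformer_name(replaced_module)     # e.g. "transformer.h"
#   is_tp_param = name in some_state_dict_key        # hypothetical use: split tp vs. non-tp weights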


class GroupQuantizer:

    def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
        self.group_size = group_size
        self.num_bits = num_bits
        self.q_int8 = q_int8
        self.num_groups = num_groups

    def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
        if not self.q_int8 or not qkv:
            inputs = torch.nn.Parameter(inputs, requires_grad=False)
            inputs.scale = torch.empty(1)
            return inputs
        q_range = 2**self.num_bits
        num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
        inputs = inputs.to(get_accelerator().current_device_name())
        input_flat = inputs.reshape(num_groups, -1).contiguous()
        input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
        input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
        scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
        input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
        inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
        out = torch.nn.Parameter(inputs_q, requires_grad=False)
        inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
        input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
        input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
                  for i in range(2)]
        out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]],
                              dim=0).reshape(num_groups, -1).contiguous()
        return out
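
# Illustrative sketch (assumption, not part of the original file): GroupQuantizer produces
# int8 weights with per-group scales attached to the returned Parameter, roughly:
#
#   quantizer = GroupQuantizer(q_int8=True)
#   w_q = quantizer.quantize(torch.randn(4096, 4096, dtype=torch.float16))
#   # w_q.dtype == torch.int8 and w_q.scale holds the per-group scale tensor
#
# With q_int8=False (or qkv=False) the tensor is returned unquantized with a dummy scale.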


def _module_match(module):
    for policy in generic_policies:
        policy = policy()
        if policy.match(module):
            return policy
    return None


def generic_injection(module, dtype=None, enable_cuda_graph=True):

    def replace_attn(child, policy):
        policy_attn = policy.attention(child)
        if policy_attn is None:
            return child
        if len(policy_attn) == 5:
            qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
        else:
            qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn

        config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=heads,
            dtype=dtype,
            triangular_masking=False,
            max_out_tokens=4096,
        )
        attn_module = DeepSpeedDiffusersAttention(config)

        def transpose(data):
            data = data.contiguous()
            data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            data.to(get_accelerator().current_device_name())
            return data

        if len(policy_attn) == 5:
            attn_module.attn_qkvw.data = transpose(qkvw.data)
        else:
            attn_module.attn_qkvw = None
            attn_module.attn_qw.data = transpose(qw.data)
            attn_module.attn_kw.data = transpose(kw.data)
            attn_module.attn_vw.data = transpose(vw.data)

        attn_module.attn_qkvb = None
        attn_module.attn_ow.data = transpose(attn_ow.data)
        attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
        return attn_module

    def replace_attn_block(child, policy):
        config = Diffusers2DTransformerConfig()
        return DeepSpeedDiffusersTransformerBlock(child, config)

    if isinstance(module, torch.nn.Module):
        pass
    else:
        if dtype not in [torch.float16, torch.half]:
            raise ValueError("Generic injection only supported with FP16")

        try:
            import diffusers
            if hasattr(diffusers.models.attention, 'CrossAttention'):
                cross_attention = diffusers.models.attention.CrossAttention
            else:
                cross_attention = diffusers.models.attention_processor.Attention
            attention_block = diffusers.models.attention.BasicTransformerBlock
            new_policies = {
                cross_attention: replace_attn,
                attention_block: replace_attn_block,
            }
        except ImportError:
            new_policies = {}

        #replace_transformer_layer(None,
        #                          module.text_encoder,
        #                          training=False,
        #                          replace_with_kernel_inject=True,
        #                          triangular_masking=True,
        #                          max_out_tokens=8192)
        from ..model_implementations.transformers.clip_encoder import DSClipEncoder
        cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
        setattr(module, 'text_encoder', cg_encoder)
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
            policy = _module_match(sub_module)

            if policy is not None:

                def _replace_module(module, policy):
                    for name, child in module.named_children():
                        _replace_module(child, policy)
                        if child.__class__ in new_policies:
                            replaced_module = new_policies[child.__class__](child, policy)
                            setattr(module, name, replaced_module)

                _replace_module(sub_module, policy)
                new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
                print(f"**** found and replaced {name} w. {type(new_module)}")
                setattr(module, name, new_module)
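
# Illustrative sketch (assumption, not part of the original file): generic_injection targets
# non-nn.Module pipelines such as a diffusers StableDiffusionPipeline, replacing its
# CrossAttention / BasicTransformerBlock layers and wrapping the CLIP text encoder:
#
#   import torch
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained("some/model", torch_dtype=torch.half)
#   generic_injection(pipe, dtype=torch.half, enable_cuda_graph=False)
#
# Passing a plain nn.Module here is a no-op; transformer models instead go through
# replace_transformer_layer() below.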


container_g = None


def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine
        config: top-level DS Inference config defined in inference/config.py
        model_config: HuggingFace model config passed from the inference/engine.py
    Returns:
        Updated nn.module with replaced transformer layers
    """
    # defining globals as internally defined functions inherit these everywhere
    quantize = (config.dtype == torch.int8)
    # todo: Refactor later. In future, let's minimize the style used above and use config.** instead

    linear_layer_setting = None
    '''
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers and embedding layers
    '''
    micro_batch_size = -1
    seed = -1
    local_rank = -1

    mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
                                          mp_size=config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
        policy = policy_cls(child, inference=inference)
        if not policy.cuda_graph_supported:
            # policy says cuda graph is not supported raise an error if set
            assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"

        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        # 1. Create a model-specific container object using the policy object.
        _container = policy_to_ds_container(policy=policy,
                                            config=config,
                                            model_config=model_config,
                                            layer_id=layer_id,
                                            child=child)
        _container.set_moe(moe)

        # 2. Set the tensor parallelism config
        _container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Initialize tensors
        _container.initialize_tensors()

        # 4. deal with data types -- needs refactor to use dtype instead of fp16
        if config.dtype in [torch.float16, torch.bfloat16, torch.int8]:
            _container.convert_to_required_dtype()

        # 5. Set the quantization config
        quantizer = GroupQuantizer(q_int8=quantize)
        _container.set_quantization_config(quantizer)

        # 6. create a DS Inference config object
        _container.create_ds_model_config()

        # 7. use the config and create the module
        _container.create_module()

        # 8. transpose the weights and bias if needed
        _container.transpose()

        # 9. deal with tensor parallelism.
        _container.apply_tensor_parallelism(mp_replace)

        # 10. copy the tensors from the model-specific container to the new module
        _container.copy_data_to_new_module()

        # 11. set global for generic checkpoint loading
        global container_g

        if container_g is None:
            container_g = _container

        return _container.module

    def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
        #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

        # 1. Create AutoTP object
        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)

        # 2. Set the tensor parallelism config
        _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Set linear policies
        _autotp.update_linear_policies()

        # 4. Replace modules
        return _autotp._replace_module(module)

    def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
        training = False  # todo: refactor this part to go in the config
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, config.triangular_masking)
        else:
            # copy relevant state from child -> new module
            if config.replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 config.triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)

        return new_module

    if checkpoint_dict is not None and not config.replace_with_kernel_inject:
        # AutoTP shard loading
        checkpoint = checkpoint_dict["checkpoints"]
        pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
        for i in range(len(checkpoint)):
            checkpoint_file = os.path.join(config.base_dir, checkpoint[i])
            replaced_module = replace_module(model=model,
                                             orig_class=orig_layer_impl,
                                             replace_fn=replace_fn,
                                             _replace_policy=config.injection_policy_tuple,
                                             checkpoint=checkpoint_file)
            pbar.update(1)
            gc.collect()
    else:
        replaced_module = replace_module(model=model,
                                         orig_class=orig_layer_impl,
                                         replace_fn=replace_fn,
                                         _replace_policy=config.injection_policy_tuple)

    quantizer = GroupQuantizer(q_int8=quantize)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None and config.replace_with_kernel_inject:
        assert container_g.ckpt_load_enabled, \
               f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')

        ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
        ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
        base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)

        if ckpt_type == 'pp' and type(checkpoint) is list:
            pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

            for i in range(len(checkpoint)):
                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           container=container_g)
                pbar.update(1)
        else:
            num_checkpoints = len(ckpt_list) // ckpt_mp_size
            tp_split_size = (world_size / ckpt_mp_size)
            sd_offset = int(rank / tp_split_size)
            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
            pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                pbar.update(1)
                ckpt_index = i * ckpt_mp_size + sd_offset
                ckpt_files = [
                    os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                    for j in range(sd_count)
                ]
                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
                load_model_with_checkpoint(replaced_module,
                                           sds,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           int(rank % tp_split_size),
                                           container=container_g)
                sds = [None for _ in sds]
                gc.collect()

            if "non_tp" in checkpoint:
                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")

                for i in range(len(checkpoint["non_tp"])):
                    pbar.update(1)
                    ckpt_file = os.path.join(base_dir1,
                                             checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
                    sds = [torch.load(ckpt_file, map_location='cpu')]
                    load_model_with_checkpoint(replaced_module,
                                               sds,
                                               mp_replace,
                                               ckpt_type,
                                               ckpt_mp_size,
                                               quantizer,
                                               int(rank % tp_split_size),
                                               container=container_g)
                    sds = [None for _ in sds]
                    gc.collect()
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if config.save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json
        num_partitions = 8

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = 'non-tp.pt'
        ckpt_files = [non_tp_ckpt_name]
        os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)
        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({k: v
                             for k, v in dict(replaced_module.state_dict()).items()
                             if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')

            dtype_reprs = {
                torch.float32: 'float32',
                torch.float16: 'float16',
                torch.int8: 'int8',
                torch.bfloat16: 'bfloat16'
            }

            ckpt_config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{config.save_mp_checkpoint_path}',
                'checkpoints': {
                    "non_tp": ckpt_files,
                    "tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
                },
                'version': 1.0,
                'parallelization': 'tp',
                'tp_size': world_size,
                'dtype': dtype_reprs[config.dtype]
            })
            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
                cfg.write(ckpt_config)

        rep_sd = replaced_module.state_dict()
        for n, p in replaced_module.named_parameters():
            if hasattr(p, 'scale'):
                rep_sd[n] = [p, p.scale]
        keys = list(rep_sd.keys())
        partition_size = (len(keys) // num_partitions + 1)
        for m in range(num_partitions):
            torch.save(
                OrderedDict({
                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
                    for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
                }), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

    return replaced_module
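
# Illustrative sketch (assumption, not part of the original file): users normally reach
# replace_transformer_layer() indirectly through deepspeed.init_inference(), which builds the
# DS inference config and optional checkpoint dict before calling into this module, roughly:
#
#   engine = deepspeed.init_inference(model,
#                                     tensor_parallel={"tp_size": world_size},
#                                     dtype=torch.half,
#                                     replace_with_kernel_inject=True)
#
# With replace_with_kernel_inject=False the AutoTP path (replace_wo_policy) is taken instead
# of the policy/container path (replace_with_policy).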


def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        config (dict): model config containing hidden size, attention heads, etc.
    Returns:
        Updated nn.module with original bert-style transformer layers
    """

    def replace_fn(child, _replace_policy, layer_id):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b

        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn,
                          _replace_policy=None)


def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
                             desired type and return a new instance.
    Returns:
        A modified ``model``.
    """
    sd = None
    if checkpoint is not None:
        sd = torch.load(checkpoint, map_location='cpu')
    policy = {}
    if orig_class is not None:
        policy.update({orig_class: (replace_fn, _replace_policy)})
    else:
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.update({orig_layer_class: (replace_fn, plcy)})
            elif plcy._orig_layer_class is not None:
                policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
    assert len(policy.items()) > 0,\
        "No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBertLayerPolicy})." +\
        " You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

    replaced_module, _ = _replace_module(model, policy, state_dict=sd)
    if checkpoint is not None:
        embedding_weight = None
        for n, p in replaced_module.named_parameters():
            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
                embedding_weight = p
        if embedding_weight is not None and hasattr(replaced_module, "lm_head") and hasattr(
                replaced_module.lm_head, "weight") and replaced_module.lm_head.weight.is_meta:
            replaced_module.lm_head.weight = embedding_weight
    return replaced_module
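
# Illustrative sketch (assumption, not part of the original file): replace_module() is the
# generic scan-and-swap entry point used above. A minimal standalone use could look like:
#
#   from transformers.models.bert.modeling_bert import BertLayer
#
#   def to_identity(child, policy, layer_id, prefix="", state_dict=None):
#       # hypothetical replace_fn, for demonstration only
#       return torch.nn.Identity()
#
#   model = replace_module(model, BertLayer, to_identity, _replace_policy=None)
#
# Passing orig_class=None instead builds the policy map from the registered replace_policies.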


from ..pipe import PipelineModule

import re


def skip_level_0_prefix(model, state_dict):
    model = str(model)
    key = re.search(r": (.*?)Model", model)
    if key is None:
        key = re.search(r": (.*?)Stack", model)
    if key is None:
        key = re.match(r"(.*?)Model", model)
    # if keys start with 'model.', don't skip level 0 prefix
    if state_dict is not None:
        for item in state_dict.keys():
            if re.match("^model[.]", item):
                return False
    if key is not None and key.group(1).lower() in ["bloom", "opt"]:
        return True
    return False


def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.
    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            replaced_module = policies[child.__class__][0](child,
                                                           policies[child.__class__][-1],
                                                           layer_id,
                                                           prefix=prefix + name,
                                                           state_dict=state_dict)
            setattr(model, name, replaced_module)
            if isinstance(model, PipelineModule):
                assert hasattr(model, 'forward_funcs'),\
                    "we require pipe-module to have the list of fwd_functions"
                model.forward_funcs[model.fwd_map[name]] = replaced_module
            layer_id += 1
        else:
            checking_key = prefix + name + '.'
            if Loading.is_load_module(child) and state_dict is not None:
                if any(checking_key in item for item in state_dict):
                    Loading.load(
                        child,
                        state_dict,
                        checking_key,
                    )
                else:
                    continue
            if len(child._buffers) != 0 and state_dict is not None:
                Loading.load_buffer(child, state_dict, checking_key)
            _, layer_id = _replace_module(child,
                                          policies,
                                          prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \
                                          prefix + name + '.',
                                          layer_id=layer_id,
                                          level_id=level_id + 1,
                                          state_dict=state_dict)

    # Add the reset_cache func to the model, so that it can be called in the beginning of text-generation.
    model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
    return model, layer_id