# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import tqdm
import deepspeed
import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd
from .load_checkpoint import load_model_with_checkpoint
import time
from .utils import policy_to_ds_container
import gc


def get_transformer_name(replaced_module):
    from .containers import supported_models
    from torch.nn import ModuleList
    transformer_name = ''
    for n, c in replaced_module.named_children():
        if c.__class__ in supported_models:
            transformer_name += n + '.'
            for name, child in c.named_children():
                if child.__class__ is ModuleList:
                    transformer_name += name
                    break
            break
    return transformer_name


class GroupQuantizer:

    def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
        self.group_size = group_size
        self.num_bits = num_bits
        self.q_int8 = q_int8
        self.num_groups = num_groups

    def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
        if not self.q_int8 or not qkv:
            inputs = torch.nn.Parameter(inputs, requires_grad=False)
            inputs.scale = torch.empty(1)
            return inputs
        q_range = 2**self.num_bits
        num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
        inputs = inputs.to(get_accelerator().current_device_name())
        input_flat = inputs.reshape(num_groups, -1).contiguous()
        input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
        input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
        scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
        input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
        inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
        out = torch.nn.Parameter(inputs_q, requires_grad=False)
        inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
        input_flat = [inputs_split[i].reshape(num_groups, -1).contiguous() for i in range(2)]
        input_min = [torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        input_max = [torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)]
        scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 / (q_range)).squeeze().unsqueeze(0)
                  for i in range(2)]
        out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]], dim=0).reshape(num_groups,
                                                                                                   -1).contiguous()
        return out
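
# Usage sketch (illustrative only, not part of the module; the tensor shape below is an
# arbitrary example): GroupQuantizer.quantize symmetrically quantizes a weight tensor to
# int8 per group and attaches the per-group scales to the returned Parameter.
#
#   quantizer = GroupQuantizer(q_int8=True)
#   weight = torch.randn(64, 256, dtype=torch.float16)
#   q_weight = quantizer.quantize(weight)  # int8 Parameter with a .scale attribute
#   assert q_weight.dtype == torch.int8 and hasattr(q_weight, 'scale')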


def _module_match(module):
    for policy in generic_policies:
        policy = policy()
        if policy.match(module):
            return policy
    return None


def generic_injection(module, dtype=None, enable_cuda_graph=True):

    def replace_attn(child, policy):
        policy_attn = policy.attention(child)
        if policy_attn is None:
            return child
        if len(policy_attn) == 5:
            qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
        else:
            qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn

        config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=heads,
            dtype=dtype,
            triangular_masking=False,
            max_out_tokens=4096,
        )
        attn_module = DeepSpeedDiffusersAttention(config)

        def transpose(data):
            data = data.contiguous()
            data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            data.to(get_accelerator().current_device_name())
            return data

        if len(policy_attn) == 5:
            attn_module.attn_qkvw.data = transpose(qkvw.data)
        else:
            attn_module.attn_qkvw = None
            attn_module.attn_qw.data = transpose(qw.data)
            attn_module.attn_kw.data = transpose(kw.data)
            attn_module.attn_vw.data = transpose(vw.data)

        attn_module.attn_qkvb = None
        attn_module.attn_ow.data = transpose(attn_ow.data)
        attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
        return attn_module

    def replace_attn_block(child, policy):
        config = Diffusers2DTransformerConfig()
        return DeepSpeedDiffusersTransformerBlock(child, config)

    if isinstance(module, torch.nn.Module):
        pass
    else:
        if dtype not in [torch.float16, torch.half]:
            raise ValueError("Generic injection only supported with FP16")

        try:
            import diffusers
            if hasattr(diffusers.models.attention, 'CrossAttention'):
                cross_attention = diffusers.models.attention.CrossAttention
            else:
                cross_attention = diffusers.models.attention_processor.Attention
            attention_block = diffusers.models.attention.BasicTransformerBlock
            new_policies = {
                cross_attention: replace_attn,
                attention_block: replace_attn_block,
            }
        except ImportError:
            new_policies = {}

        #replace_transformer_layer(None,
        #                          module.text_encoder,
        #                          training=False,
        #                          replace_with_kernel_inject=True,
        #                          triangular_masking=True,
        #                          max_out_tokens=8192)
        from ..model_implementations.transformers.clip_encoder import DSClipEncoder
        cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
        setattr(module, 'text_encoder', cg_encoder)
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
            policy = _module_match(sub_module)

            if policy is not None:

                def _replace_module(module, policy):
                    for name, child in module.named_children():
                        _replace_module(child, policy)
                        if child.__class__ in new_policies:
                            replaced_module = new_policies[child.__class__](child, policy)
                            setattr(module, name, replaced_module)

                _replace_module(sub_module, policy)
                new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
                print(f"**** found and replaced {name} w. {type(new_module)}")
                setattr(module, name, new_module)
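
# Note (illustrative, assumed call path): generic_injection is normally reached through
# deepspeed.init_inference on a diffusers pipeline rather than called directly, e.g.
#
#   pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
#   pipe = deepspeed.init_inference(pipe, dtype=torch.float16, replace_with_kernel_inject=True)
#
# which routes the pipeline's cross-attention and transformer blocks through the
# replace_attn / replace_attn_block policies defined above.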


container_g = None


def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine
        config: top-level DS Inference config defined in inference/config.py
        model_config: HuggingFace model config passed from the inference/engine.py

    Returns:
        Updated nn.module with replaced transformer layers
    """
    # defining globals as internally defined functions inherit these everywhere
    quantize = (config.dtype == torch.int8)
    # todo: Refactor later. In future, let's minimize the style used above and use config.** instead

    linear_layer_setting = None
    '''
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers and embedding layers
    '''
    micro_batch_size = -1
    seed = -1
    local_rank = -1

    mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
                                          mp_size=config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
        policy = policy_cls(child, inference=inference)
        if not policy.cuda_graph_supported:
            # policy says cuda graph is not supported raise an error if set
            assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"

        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            num_experts = child.mlp.num_experts
            moe = True

        # 1. Create a model-specific container object using the policy object.
        _container = policy_to_ds_container(policy=policy,
                                            config=config,
                                            model_config=model_config,
                                            layer_id=layer_id,
                                            child=child)
        _container.set_moe(moe)

        # 2. Set the tensor parallelism config
        _container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Initialize tensors
        _container.initialize_tensors()

        # 4. deal with data types -- needs refactor to use dtype instead of fp16
        if config.dtype in [torch.float16, torch.bfloat16, torch.int8]:
            _container.convert_to_required_dtype()

        # 5. Set the quantization config
        quantizer = GroupQuantizer(q_int8=quantize)
        _container.set_quantization_config(quantizer)

        # 6. create a DS Inference config object
        _container.create_ds_model_config()

        # 7. use the config and create the module
        _container.create_module()

        # 8. transpose the weights and bias if needed
        _container.transpose()

        # 9. deal with tensor parallelism.
        _container.apply_tensor_parallelism(mp_replace)

        # 10. copy the tensors from the model-specific container to the new module
        _container.copy_data_to_new_module()

        # 11. set global for generic checkpoint loading
        global container_g

        if container_g is None:
            container_g = _container

        return _container.module

    def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
        #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

        # 1. Create AutoTP object
        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)

        # 2. Set the tensor parallelism config
        _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Try to get num_key_heads from model_config.num_key_value_heads
        num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

        # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
        set_num_kv_heads(num_kv_heads)

        # 4.1 Get n_embd
        n_embd = None
        multi_query_n_embd_names = ['n_embd']
        for name in multi_query_n_embd_names:
            if hasattr(model_config, name):
                n_embd = getattr(model_config, name)
            if n_embd is not None:
                break

        # 4.2 set n_embd
        set_n_embd(n_embd)

        # 5. Set linear policies
        _autotp.update_linear_policies()

        # 6. Replace modules
        if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
            return _autotp._replace_last_linear_module(module)
        return _autotp._replace_module(module)

    def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
        training = False  # todo: refactor this part to go in the config
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, config.triangular_masking)
        else:
            # copy relevant state from child -> new module
            if config.replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 config.triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)

        return new_module

    def set_lm_head(module):
        embedding_weight = None
        for n, p in module.named_parameters():
            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
                embedding_weight = p
        if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
                module.lm_head, "weight") and module.lm_head.weight.is_meta:
            module.lm_head.weight = embedding_weight
        # enable tensor parallel for the last linear
        if hasattr(module, "lm_head") and hasattr(module.lm_head,
                                                  "weight") and not module.lm_head.weight.is_meta and isinstance(
                                                      module.lm_head, torch.nn.Linear):
            module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
        elif hasattr(module, "embed_out") and hasattr(module.embed_out,
                                                      "weight") and not module.embed_out.weight.is_meta and isinstance(
                                                          module.embed_out, torch.nn.Linear):
            module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")

        return module

    if checkpoint_dict is not None and not config.replace_with_kernel_inject:
        # AutoTP shard loading
        checkpoint = checkpoint_dict["checkpoints"]
        pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

        for i in range(len(checkpoint)):
            checkpoint_file = os.path.join(config.base_dir, checkpoint[i])
            replaced_module = replace_module(model=model,
                                             orig_class=orig_layer_impl,
                                             replace_fn=replace_fn,
                                             _replace_policy=config.injection_policy_tuple,
                                             checkpoint=checkpoint_file)
            pbar.update(1)
            gc.collect()
        replaced_module = set_lm_head(replaced_module)
    else:
        replaced_module = replace_module(model=model,
                                         orig_class=orig_layer_impl,
                                         replace_fn=replace_fn,
                                         _replace_policy=config.injection_policy_tuple)

    quantizer = GroupQuantizer(q_int8=quantize)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None and config.replace_with_kernel_inject:
        assert container_g.ckpt_load_enabled, \
                f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
        ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
        base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)

        if ckpt_type == 'pp' and type(checkpoint) is list:
            pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

            for i in range(len(checkpoint)):
                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           container=container_g)
                pbar.update(1)
        else:
            num_checkpoints = len(ckpt_list) // ckpt_mp_size
            tp_split_size = (world_size / ckpt_mp_size)
            sd_offset = int(rank / tp_split_size)
            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
            pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                pbar.update(1)
                ckpt_index = i * ckpt_mp_size + sd_offset
                ckpt_files = [
                    os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                    for j in range(sd_count)
                ]
                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
                load_model_with_checkpoint(replaced_module,
                                           sds,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           int(rank % tp_split_size),
                                           container=container_g)
                sds = [None for _ in sds]
                gc.collect()

            if "non_tp" in checkpoint:
                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")

                for i in range(len(checkpoint["non_tp"])):
                    pbar.update(1)
                    ckpt_file = os.path.join(base_dir1,
                                             checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
                    sds = [torch.load(ckpt_file, map_location='cpu')]
                    load_model_with_checkpoint(replaced_module,
                                               sds,
                                               mp_replace,
                                               ckpt_type,
                                               ckpt_mp_size,
                                               quantizer,
                                               int(rank % tp_split_size),
                                               container=container_g)
                    sds = [None for _ in sds]
                    gc.collect()
        set_lm_head(replaced_module)
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if config.save_mp_checkpoint_path is not None:
        from collections import OrderedDict
        import json
        num_partitions = 8

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = f'non-tp.pt'
        ckpt_files = [non_tp_ckpt_name]
        os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)

        if not dist.is_initialized() or dist.get_rank() == 0:
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({k: v
                             for k, v in dict(replaced_module.state_dict()).items()
                             if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')

            dtype_reprs = {
                torch.float32: 'float32',
                torch.float16: 'float16',
                torch.int8: 'int8',
                torch.bfloat16: 'bfloat16'
            }

            ckpt_config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{config.save_mp_checkpoint_path}',
                'checkpoints': {
                    "non_tp": ckpt_files,
                    "tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
                },
                'version': 1.0,
                'parallelization': 'tp',
                'tp_size': world_size,
                'dtype': dtype_reprs[config.dtype]
            })
            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
                cfg.write(ckpt_config)

        rep_sd = replaced_module.state_dict()
        for n, p in replaced_module.named_parameters():
            if hasattr(p, 'scale'):
                rep_sd[n] = [p, p.scale]
        keys = list(rep_sd.keys())
        partition_size = (len(keys) // num_partitions + 1)
        for m in range(num_partitions):
            torch.save(
                OrderedDict({
                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
                    for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
                }), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

    return replaced_module
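
# Illustrative sketch (assumed call site, not part of this module): replace_transformer_layer
# is normally driven by the DeepSpeed InferenceEngine rather than called directly, roughly:
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)
#
# init_inference builds the DS Inference config and eventually invokes
# replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config).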


def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        config (dict): model config containing hidden size, attention heads, etc.

    Returns:
        Updated nn.module with original bert-style transformer layers
    """

    def replace_fn(child, _replace_policy, layer_id):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        qw, kw, vw = torch.chunk(qkvw, 3, axis=0)
        qb, kb, vb = torch.chunk(qkvb, 3, axis=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b

        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn,
                          _replace_policy=None)
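
# Illustrative sketch (assumed usage): revert_transformer_layer undoes the BERT-style kernel
# injection, e.g. to export or continue training with the original HuggingFace layers.
# BertLayer and bert_config below stand in for the originally replaced class and its config.
#
#   from transformers.models.bert.modeling_bert import BertLayer
#   model = revert_transformer_layer(BertLayer, model, bert_config)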


def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
                             desired type and return a new instance.
    Returns:
        A modified ``model``.
    """
    sd = None
    if checkpoint is not None:
        if checkpoint.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd = load_file(checkpoint)
        else:
            sd = torch.load(checkpoint, map_location='cpu')

    policy = {}
    if orig_class is not None:
        policy.update({orig_class: (replace_fn, _replace_policy)})
    else:
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.update({orig_layer_class: (replace_fn, plcy)})
            elif plcy._orig_layer_class is not None:
                policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
    assert len(policy.items()) > 0,\
        "No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBertLayerPolicy})." +\
        "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

    replaced_module, _ = _replace_module(model, policy, state_dict=sd)
    return replaced_module
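
# Illustrative sketch (hypothetical direct use; in practice replace_transformer_layer and
# revert_transformer_layer call this helper). SomeLayer, MyPolicy, MyFusedLayer and
# my_replace_fn are placeholder names, not part of DeepSpeed.
#
#   def my_replace_fn(child, policy, layer_id=0, prefix="", state_dict=None):
#       return MyFusedLayer(child)
#
#   model = replace_module(model, orig_class=SomeLayer, replace_fn=my_replace_fn, _replace_policy=MyPolicy)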


from ..pipe import PipelineModule

import re


def skip_level_0_prefix(model, state_dict):
    model = str(model)
    key = re.search(r": (.*?)Model", model)
    if key is None:
        key = re.search(r": (.*?)Stack", model)
    if key is None:
        key = re.match(r"(.*?)Model", model)
    # if keys start with 'model.', don't skip level 0 prefix
    if state_dict is not None:
        for item in state_dict.keys():
            if re.match("^model[.]", item):
                return False
    if key is not None and key.group(1).lower() in ["bloom", "opt"]:
        return True
    return False


def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.
    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            replaced_module = policies[child.__class__][0](child,
                                                           policies[child.__class__][-1],
                                                           layer_id,
                                                           prefix=prefix + name,
                                                           state_dict=state_dict)
            setattr(model, name, replaced_module)
            if isinstance(model, PipelineModule):
                assert hasattr(model, 'forward_funcs'),\
                    "we require pipe-module to have the list of fwd_functions"
                model.forward_funcs[model.fwd_map[name]] = replaced_module
            layer_id += 1
        else:
            checking_key = prefix + name + '.'
            if Loading.is_load_module(child) and state_dict is not None:
                if any(checking_key in item for item in state_dict):
                    Loading.load(
                        child,
                        state_dict,
                        checking_key,
                    )
                else:
                    continue
            if len(child._buffers) != 0 and state_dict is not None:
                Loading.load_buffer(child, state_dict, checking_key)
            _, layer_id = _replace_module(child,
                                          policies,
                                          prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \
                                          prefix + name + '.',
                                          layer_id=layer_id,
                                          level_id=level_id + 1,
                                          state_dict=state_dict)

    # Add the reset_cache func to the model, so that it can be called in the beginning of text-generation.
    model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
    return model, layer_id