load_checkpoint.py

'''Copyright The Microsoft DeepSpeed Team'''
from torch import nn
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference
from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference
import deepspeed.ops.transformer as transformer_inference
from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
import torch
import gc
from deepspeed.accelerator import get_accelerator
import re
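
# Populates a DeepSpeed inference module (r_module) in place from one or more checkpoint
# state_dicts: materializes meta tensors, splits or merges tensor-parallel shards across
# ranks, and optionally applies INT8 weight quantization via the provided weight_quantizer.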
def load_model_with_checkpoint(r_module,
                               sd,
                               mp_replace,
                               ckpt_type,
                               ckpt_mp_size,
                               weight_quantizer=None,
                               rank=0,
                               container=None):
    error_msgs = []

    def prefix_check():
        # if keys start with 'model.', don't skip level 0 prefix
        for key in sd[0].keys():
            if re.match("^model[.]", key):
                return False
        return True

    skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix
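
    # In-place transpose helper: copies the transposed values back into the original
    # storage and returns a view with the last two dimensions swapped, avoiding a second
    # full-size allocation.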
    def transpose(data):
        with torch.no_grad():
            data = data.contiguous()
            data1 = data.transpose(-1, -2).reshape(-1)
            data.reshape(-1).copy_(data1)
            data1 = None
        return data.reshape(data.shape[-1], data.shape[-2])
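
    # Loader for simple (non-transformer) layers: copies 'weight' and, when present,
    # 'bias' from the rank-0 state_dict through mp_replace, materializing meta bias
    # tensors on the CPU first.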
    def load(module, prefix):
        args = (sd[0], prefix, {}, True, [], [], error_msgs)
        if hasattr(module, 'weight'):
            module.weight = mp_replace.copy(module.weight.data, sd[0][prefix + 'weight'])
        if prefix + 'bias' in sd[0].keys():
            if module.bias.data.is_meta:
                # a meta tensor cannot be cast or copied to, so replace it with a normal tensor first
                module.bias = torch.nn.parameter.Parameter(
                    data=torch.empty_like(module.bias.data, device="cpu"),
                    requires_grad=module.bias.data.requires_grad)
            module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
        args = None
        gc.collect()
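
    # Loader for DeepSpeed transformer blocks. For 'tp' checkpoints each parameter is
    # loaded directly, splitting or merging shards across ranks as needed; otherwise the
    # policy container performs the copy via container.load_params().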
    def load_transformer_layer(module, prefix):
        if ckpt_type == "tp":

            def load_parameters(module, prefix):
                for n, p in module.named_parameters():
                    if prefix + n in sd[0] and len(n.split('.')) == 1:
                        if type(sd[0][prefix + n]) is list:
                            tmp_data, scale = sd[0][prefix + n]
                            scale = scale.to(get_accelerator().current_device_name())
                            # set the quantizer number of groups using the checkpoint scale shape
                            weight_quantizer.num_groups = scale.shape[0]
                        else:
                            tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
                            scale = None
                        src_shape = tmp_data.shape
                        dst_shape = p.shape
                        inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
                        outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
                        if (len(src_shape) == 2 and len(dst_shape) == 2):
                            if (src_shape[inner_dim] == dst_shape[0]
                                    and src_shape[outer_dim] == dst_shape[1]):
                                if tmp_data.dtype != torch.int8:
                                    p = weight_quantizer.quantize(
                                        transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
                                else:
                                    p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
                                    p.scale = scale
                                setattr(module, n, p)
                            else:
                                # Shape mismatch: either slice out this rank's partition from a
                                # larger checkpoint shard, or merge the shards from all ranks.
                                dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
                                dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
                                if src_shape[dim] > dst_shape[dim1]:
                                    weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
                                        get_accelerator().current_device_name())
                                    assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank + 1), \
                                        '''ERROR: The INT8 checkpoint does not provide enough quantization scales for this larger TP size! \
                                        Please regenerate the INT8 checkpoint from the FP16 checkpoint with the desired sharding parameters.'''
                                    scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
                                        weight_quantizer.num_groups, -1).contiguous()
                                else:
                                    assert tmp_data.dtype != torch.int8, \
                                        '''Merging of checkpoints is not supported when using an INT8 checkpoint! \
                                        Please use as many GPUs as the checkpoint's TP size.'''
                                    all_data = [
                                        sd[j][prefix + n] if type(sd[j][prefix + n]) is list else
                                        sd[j][prefix + n].to(get_accelerator().current_device_name())
                                        for j in range(len(sd))
                                    ]
                                    # Check if the weight tensor is for the fused QKV parameter; if so, the
                                    # Q, K and V slices of each shard are re-grouped before concatenation.
                                    if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
                                        qkv_size = src_shape[outer_dim] // 3
                                        src_split = [
                                            torch.split(src[0].data, qkv_size, dim=outer_dim) for src in all_data
                                        ]
                                        weight_partition = torch.cat([
                                            torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
                                            for i in range(len(src_split[0]))
                                        ],
                                                                     dim=dim)
                                    else:
                                        weight_partition = torch.cat([
                                            ad[0].to(get_accelerator().current_device_name())
                                            if type(ad) is list else ad for ad in all_data
                                        ],
                                                                     dim=dim)
                                    if tmp_data.dtype == torch.int8:
                                        scale = torch.cat(
                                            [ad[1].to(get_accelerator().current_device_name()) for ad in all_data],
                                            dim=dim)

                                if tmp_data.dtype != torch.int8:
                                    weight_partition = weight_quantizer.quantize(
                                        transpose(weight_partition),
                                        parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
                                        weight_quantizer.quantize(weight_partition)
                                else:
                                    weight_partition = torch.nn.parameter.Parameter(weight_partition,
                                                                                    requires_grad=False)
                                    weight_partition.scale = scale
                                setattr(module, n, weight_partition)
                        else:
                            if src_shape[0] == dst_shape[0]:
                                p.data.copy_(tmp_data)
                            else:
                                if src_shape[0] > dst_shape[0]:
                                    bias_split = torch.split(
                                        tmp_data,
                                        dst_shape[-1])[rank].to(
                                            get_accelerator().current_device_name()).contiguous()
                                    p.data.copy_(bias_split)
                                else:
                                    # Check if the weight tensor is for the QKV parameter
                                    if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
                                        qkv_size = src_shape[0] // 3
                                        src_split = [
                                            torch.split(sd[j][prefix + n], qkv_size, dim=0) for j in range(len(sd))
                                        ]
                                        p.data.copy_(
                                            torch.cat([
                                                torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
                                                for i in range(len(src_split[0]))
                                            ],
                                                      dim=0).to(
                                                          get_accelerator().current_device_name()).contiguous())
                                    else:
                                        p.data.copy_(
                                            torch.cat([sd[j][prefix + n] for j in range(len(sd))],
                                                      dim=0).to(
                                                          get_accelerator().current_device_name()).contiguous())

            load_parameters(module, prefix)
            for n, child in module.named_children():
                load_parameters(child, prefix + n + '.')
        else:
            container.load_params(module, sd[0], weight_quantizer, mp_replace, prefix)

    try:
        import transformers
        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
    except:
        OPTLearnedPositionalEmbedding = None
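
    # Map each supported module class to the loader that knows how to populate it.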
    layer_policies = {
        nn.Linear: load,
        nn.Embedding: load,
        nn.LayerNorm: load,
        EmbeddingLayer: load,
        LinearLayer: load,
        Normalize: load,
        transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
        DeepSpeedBloomInference: load_transformer_layer,
        DeepSpeedGPTInference: load_transformer_layer,
        DeepSpeedBERTInference: load_transformer_layer,
        DeepSpeedMegatronGPTInference: load_transformer_layer,
        DeepSpeedOPTInference: load_transformer_layer,
        OPTLearnedPositionalEmbedding: load,
        OPTEmbedding: load
    }

    all_ds_ids = {}
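
    # Walk the module tree: children whose class has a policy are loaded directly (meta-device
    # or empty, e.g. ZeRO-partitioned, children are first replaced with materialized DeepSpeed
    # layers), while all other children are recursed into with an extended key prefix.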
    def load_module_recursive(module, prefix='', level=0):
        for name, child in module.named_children():
            if child.__class__ in layer_policies:
                checking_key = prefix + name + '.'
                if not any(checking_key in item for item in sd[0].keys()):
                    if hasattr(child, 'weight') and hasattr(child.weight, 'ds_id') \
                            and child.weight.ds_id in all_ds_ids:
                        if child.__class__ is nn.Linear:
                            child = LinearLayer(weight=all_ds_ids[child.weight.ds_id])
                            setattr(module, name, child)
                    continue
                child_params = list(child.parameters())
                if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
                    if child.weight.is_meta:
                        ds_shape = child.weight.shape
                    else:
                        ds_shape = child.weight.ds_shape
                    if child.__class__ is nn.LayerNorm:
                        child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
                        setattr(module, name, child)
                    elif child.__class__ is nn.Linear:
                        child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
                        setattr(module, name, child)
                    elif child.__class__ is OPTLearnedPositionalEmbedding:
                        child = OPTEmbedding(weight_shape=ds_shape)
                        setattr(module, name, child)
                    else:
                        ds_id = None
                        if hasattr(child.weight, 'ds_id'):
                            ds_id = child.weight.ds_id
                        child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
                        if ds_id is not None:
                            all_ds_ids[ds_id] = child.weight
                        setattr(module, name, child)
                layer_policies[child.__class__](child, prefix + name + '.')
            else:
                load_module_recursive(
                    child,
                    prefix if (level == 0 and ckpt_type == 'pp') and skip_level_0_prefix
                    else prefix + name + '.',
                    level + 1)

    load_module_recursive(r_module)
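
    # If the checkpoint left lm_head on the meta device, tie it to the token embedding
    # weight (word_embeddings / embed_tokens / wte).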
    embedding_weight = None
    for n, p in r_module.named_parameters():
        if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
            embedding_weight = p
    if embedding_weight is not None and r_module.lm_head.weight.is_meta:
        r_module.lm_head.weight = embedding_weight

    # release the checkpoint state_dicts once all parameters have been copied
    for sd_ in sd:
        del sd_
    sd = None
    gc.collect()
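
# Illustrative usage sketch (not part of this module; the model, mp_replace helper,
# quantizer and policy container below are assumed to be constructed elsewhere, e.g.
# during DeepSpeed inference-engine initialization):
#
#   sd = [torch.load(ckpt_path, map_location='cpu')]   # one state_dict per checkpoint shard
#   load_model_with_checkpoint(ds_module, sd, mp_replace,
#                              ckpt_type='pp', ckpt_mp_size=1,
#                              weight_quantizer=quantizer,
#                              rank=0, container=container)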