load_checkpoint.py

'''Copyright The Microsoft DeepSpeed Team'''
from torch import nn
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference
from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference

import deepspeed.ops.transformer as transformer_inference
from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
import torch
import gc
from deepspeed.accelerator import get_accelerator


def load_model_with_checkpoint(r_module,
                               sd,
                               mp_replace,
                               ckpt_type,
                               ckpt_mp_size,
                               weight_quantizer=None,
                               rank=0,
                               container=None):
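    """Load a (possibly sharded) checkpoint into an injected DeepSpeed module.

    A brief sketch of the arguments, inferred from their use below:
        r_module: root module whose submodules are loaded/replaced in place.
        sd: list of checkpoint state dicts, one per checkpoint shard.
        mp_replace: tensor-slicing helper; its copy() shards checkpoint
            tensors across model-parallel ranks.
        ckpt_type: 'tp' loads tensor-parallel shards here; any other value
            defers to container.load_params().
        ckpt_mp_size: model-parallel size the checkpoint was saved with.
        weight_quantizer: optional quantizer used when loading INT8 weights.
        rank: this process's model-parallel rank.
        container: model-specific policy container for non-'tp' checkpoints.
    """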
    error_msgs = []

    def transpose(data):
        with torch.no_grad():
            data = data.contiguous()
            data1 = data.transpose(-1, -2).reshape(-1)
            data.reshape(-1).copy_(data1)
            data1 = None
        return data.reshape(data.shape[-1], data.shape[-2])
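    # Note: transpose() above rewrites the tensor in place through a flattened
    # copy, avoiding a second full-size temporary for large weight matrices.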
    def load(module, prefix):
        args = (sd[0], prefix, {}, True, [], [], error_msgs)

        if hasattr(module, 'weight'):
            module.weight = mp_replace.copy(module.weight.data, sd[0][prefix + 'weight'])
        if prefix + 'bias' in sd[0].keys():
            if module.bias.data.is_meta:
                # a meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
                module.bias = torch.nn.parameter.Parameter(
                    data=torch.empty_like(module.bias.data, device="cpu"),
                    requires_grad=module.bias.data.requires_grad)
            module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
        args = None
        gc.collect()
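    # mp_replace appears to be DeepSpeed's ReplaceWithTensorSlicing helper: its
    # copy(dst, src) slices the checkpoint tensor for this model-parallel rank
    # before copying it into the destination parameter.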
    def load_transformer_layer(module, prefix):
        if ckpt_type == "tp":

            def load_parameters(module, prefix):
                for n, p in module.named_parameters():
                    if prefix + n in sd[0] and len(n.split('.')) == 1:
                        if type(sd[0][prefix + n]) is list:
                            # quantized checkpoints store [weight, scale] pairs
                            tmp_data, scale = sd[0][prefix + n]
                            scale = scale.to(get_accelerator().current_device_name())
                            # set the quantizer number of groups using the checkpoint scale shape
                            weight_quantizer.num_groups = scale.shape[0]
                        else:
                            tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
                            scale = None
                        src_shape = tmp_data.shape
                        dst_shape = p.shape
                        inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
                        outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
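                        # For INT8 checkpoints the weights appear to be stored
                        # pre-transposed, hence the swapped inner/outer indices.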
                        if len(src_shape) == 2 and len(dst_shape) == 2:
                            if (src_shape[inner_dim] == dst_shape[0]
                                    and src_shape[outer_dim] == dst_shape[1]):
                                if tmp_data.dtype != torch.int8:
                                    p = weight_quantizer.quantize(
                                        transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
                                else:
                                    p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
                                    p.scale = scale
                                setattr(module, n, p)
                            else:
                                dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
                                dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
                                if src_shape[dim] > dst_shape[dim1]:
                                    # checkpoint was saved with a smaller TP size: take this rank's slice
                                    weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
                                        get_accelerator().current_device_name())
                                    assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank + 1), \
                                        '''ERROR: the quantization scales for the larger TP size are required when loading an INT8 checkpoint! \
                                        Please use the FP16 checkpoint to generate an INT8 checkpoint with the new sharding parameters!'''
                                    scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
                                        weight_quantizer.num_groups, -1).contiguous()
                                else:
                                    # checkpoint was saved with a larger TP size: merge the shards
                                    assert tmp_data.dtype != torch.int8, \
                                        '''Merging of checkpoints is not supported when using an INT8 checkpoint! \
                                        Please use as many GPUs as the TP size of the checkpoint.'''
                                    all_data = [
                                        sd[j][prefix + n] if type(sd[j][prefix + n]) is list else
                                        sd[j][prefix + n].to(get_accelerator().current_device_name())
                                        for j in range(len(sd))
                                    ]
                                    # Check if the weight tensor is for the QKV parameter
                                    if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
                                        qkv_size = src_shape[outer_dim] // 3
                                        src_split = [
                                            torch.split(src[0].data, qkv_size, dim=outer_dim) for src in all_data
                                        ]
                                        # re-interleave the q/k/v slices taken from every shard
                                        weight_partition = torch.cat([
                                            torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
                                            for i in range(len(src_split[0]))
                                        ], dim=dim)
                                    else:
                                        weight_partition = torch.cat([
                                            ad[0].to(get_accelerator().current_device_name())
                                            if type(ad) is list else ad for ad in all_data
                                        ], dim=dim)
                                    if tmp_data.dtype == torch.int8:
                                        scale = torch.cat([
                                            ad[1].to(get_accelerator().current_device_name())
                                            for ad in all_data
                                        ], dim=dim)
                                if tmp_data.dtype != torch.int8:
                                    weight_partition = weight_quantizer.quantize(
                                        transpose(weight_partition),
                                        parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
                                        weight_quantizer.quantize(weight_partition)
                                else:
                                    weight_partition = torch.nn.parameter.Parameter(weight_partition,
                                                                                    requires_grad=False)
                                    weight_partition.scale = scale
                                setattr(module, n, weight_partition)
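                        # 1-D parameters (biases, norms) below: copied directly,
                        # sliced per rank, or concatenated (re-interleaving QKV).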
                        else:
                            if src_shape[0] == dst_shape[0]:
                                p.data.copy_(tmp_data)
                            else:
                                if src_shape[0] > dst_shape[0]:
                                    bias_split = torch.split(tmp_data, dst_shape[-1])[rank].to(
                                        get_accelerator().current_device_name()).contiguous()
                                    p.data.copy_(bias_split)
                                else:
                                    # Check if the bias tensor is for the QKV parameter
                                    if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
                                        qkv_size = src_shape[0] // 3
                                        src_split = [
                                            torch.split(sd[j][prefix + n], qkv_size, dim=0)
                                            for j in range(len(sd))
                                        ]
                                        p.data.copy_(
                                            torch.cat([
                                                torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
                                                for i in range(len(src_split[0]))
                                            ], dim=0).to(get_accelerator().current_device_name()).contiguous())
                                    else:
                                        p.data.copy_(
                                            torch.cat([sd[j][prefix + n] for j in range(len(sd))],
                                                      dim=0).to(get_accelerator().current_device_name()).contiguous())

            load_parameters(module, prefix)
            for n, child in module.named_children():
                load_parameters(child, prefix + n + '.')
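        # Non-TP checkpoints: defer loading to the model-specific container policy.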
        else:
            container.load_params(module, sd[0], weight_quantizer, mp_replace, prefix)
    try:
        import transformers
        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
    except:
        # transformers is unavailable (or too old to ship OPT)
        OPTLearnedPositionalEmbedding = None
    layer_policies = {
        nn.Linear: load,
        nn.Embedding: load,
        nn.LayerNorm: load,
        EmbeddingLayer: load,
        LinearLayer: load,
        Normalize: load,
        transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
        DeepSpeedBloomInference: load_transformer_layer,
        DeepSpeedGPTInference: load_transformer_layer,
        DeepSpeedBERTInference: load_transformer_layer,
        DeepSpeedMegatronGPTInference: load_transformer_layer,
        DeepSpeedOPTInference: load_transformer_layer,
        OPTLearnedPositionalEmbedding: load,
        OPTEmbedding: load
    }
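    # The policy table above dispatches each module class to its loader: plain
    # layers go through load(), fused DeepSpeed transformer blocks through
    # load_transformer_layer().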
    all_ds_ids = {}

    def load_module_recursive(module, prefix='', level=0):
        for name, child in module.named_children():
            if child.__class__ in layer_policies:
                checking_key = prefix + name + '.'
                if not any(checking_key in item for item in sd[0].keys()):
                    if hasattr(child, 'weight') and \
                            (hasattr(child.weight, 'ds_id') and child.weight.ds_id in all_ds_ids):
                        prefix1 = all_ds_ids[child.weight.ds_id]
                        if child.__class__ is nn.Linear:
                            child = LinearLayer(weight=all_ds_ids[child.weight.ds_id])
                            setattr(module, name, child)
                    continue
                child_params = list(child.parameters())
                if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
                    if child.weight.is_meta:
                        ds_shape = child.weight.shape
                    else:
                        ds_shape = child.weight.ds_shape
                    if child.__class__ is nn.LayerNorm:
                        child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
                        setattr(module, name, child)
                    elif child.__class__ is nn.Linear:
                        child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
                        setattr(module, name, child)
                    elif child.__class__ is OPTLearnedPositionalEmbedding:
                        child = OPTEmbedding(weight_shape=ds_shape)
                        setattr(module, name, child)
                    else:
                        ds_id = None
                        if hasattr(child.weight, 'ds_id'):
                            ds_id = child.weight.ds_id
                        child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
                        if ds_id is not None:
                            all_ds_ids[ds_id] = child.weight
                        setattr(module, name, child)
                layer_policies[child.__class__](child, prefix + name + '.')
            else:
                load_module_recursive(
                    child,
                    prefix if (level == 0 and ckpt_type == 'pp') and container.policy.use_load_prefix else
                    prefix + name + '.',
                    level + 1)

    load_module_recursive(r_module)
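    # If an input embedding was loaded and lm_head is still a meta tensor, tie
    # lm_head to the embedding weight (shared-embedding models).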
    embedding_weight = None
    for n, p in r_module.named_parameters():
        if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
            embedding_weight = p
    if embedding_weight is not None and r_module.lm_head.weight.is_meta:
        r_module.lm_head.weight = embedding_weight

    # drop references to the checkpoint state dicts so they can be garbage collected
    for sd_ in sd:
        del sd_
    sd = None
    gc.collect()
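
# Illustrative call site (hypothetical variable names; in DeepSpeed this function
# is driven by the module-injection machinery in replace_module.py):
#
#     sd = [torch.load(p, map_location='cpu') for p in checkpoint_paths]
#     load_model_with_checkpoint(ds_model,
#                                sd,
#                                mp_replace,
#                                ckpt_type='tp',
#                                ckpt_mp_size=len(sd),
#                                weight_quantizer=quantizer,
#                                rank=dist.get_rank(),
#                                container=container)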