# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from torch import nn
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference
from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference
from deepspeed.model_implementations.transformers.ds_llama2 import DeepSpeedLlama2Inference

import deepspeed.ops.transformer as transformer_inference
from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding, RMSNormalize
import torch
import gc
from deepspeed.accelerator import get_accelerator
import re


def load_model_with_checkpoint(r_module,
                               sd,
                               mp_replace,
                               ckpt_type,
                               ckpt_mp_size,
                               weight_quantizer=None,
                               rank=0,
                               container=None):
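    """Load the checkpoint state dict(s) `sd` into the injected inference module `r_module`.

    `mp_replace` copies (and, for tensor parallelism, shards) each tensor onto the local
    rank, `ckpt_type` selects how transformer layers are loaded ("tp" checkpoints are
    already sharded per rank), and `weight_quantizer` is used for INT8 checkpoints.
    """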
    error_msgs = []

    def prefix_check():
        # if keys start with 'model.' or 'transformer.', don't skip level 0 prefix
        for key in sd[0].keys():
            # OPT models
            if re.match("^model[.]", key):
                return False
            # BLOOM models
            if re.match("^transformer[.]", key):
                return False
        return True

    skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix

    def transpose(data):
        # transpose the values in place, writing them back into the tensor's own storage
        with torch.no_grad():
            data = data.contiguous()
            data1 = data.transpose(-1, -2).reshape(-1)
            data.reshape(-1).copy_(data1)
            data1 = None
        return data.reshape(data.shape[-1], data.shape[-2])
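
    # `load` copies parameters that map one-to-one onto checkpoint entries (embeddings,
    # norms, plain linear layers); `mp_replace.copy` performs the actual copy and any
    # tensor-parallel slicing.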
    def load(module, prefix):
        args = (sd[0], prefix, {}, True, [], [], error_msgs)

        if hasattr(module, 'weight'):
            module.weight = mp_replace.copy(module.weight.data, sd[0][prefix + 'weight'])
        if prefix + 'bias' in sd[0].keys():
            if module.bias.data.is_meta:
                # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                           requires_grad=module.bias.data.requires_grad)
            module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
        args = None
        gc.collect()
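
    # Transformer blocks: "tp" checkpoints are loaded parameter-by-parameter below,
    # re-sharding or merging them to match the current tensor-parallel size; all other
    # checkpoint types are delegated to the policy container.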
    def load_transformer_layer(module, prefix):
        if ckpt_type == "tp":

            def load_parameters(module, prefix):
                for n, p in module.named_parameters():
                    if prefix + n in sd[0] and len(n.split('.')) == 1:
                        if type(sd[0][prefix + n]) is list:
                            # quantized checkpoints store a [weight, scale] pair per parameter
                            tmp_data, scale = sd[0][prefix + n]
                            scale = scale.to(get_accelerator().current_device_name())
                            # set the quantizer number of groups using the checkpoint scale shape
                            weight_quantizer.num_groups = scale.shape[0]
                        else:
                            tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
                            scale = None
                        src_shape = tmp_data.shape
                        dst_shape = p.shape
                        # INT8 checkpoints store weights transposed relative to the module parameter,
                        # so swap which axis is treated as inner/outer
                        inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
                        outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
                        if (len(src_shape) == 2 and len(dst_shape) == 2):
                            if (src_shape[inner_dim] == dst_shape[0] and src_shape[outer_dim] == dst_shape[1]):
                                # shapes already match this rank's partition: quantize if needed and attach
                                if tmp_data.dtype != torch.int8:
                                    p = weight_quantizer.quantize(
                                        transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
                                else:
                                    p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
                                    p.scale = scale
                                setattr(module, n, p)
                            else:
                                dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
                                dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
                                if src_shape[dim] > dst_shape[dim1]:
                                    # checkpoint shard is larger than this rank's partition: split and keep our slice
                                    weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
                                        get_accelerator().current_device_name())
                                    assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank + 1), \
                                        '''ERROR: The quantization scales are required when loading an INT8 checkpoint with a larger TP-size! \
                                           Please use the FP16 checkpoint to generate the INT8 checkpoint with the new sharding parameters!'''
                                    scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
                                        weight_quantizer.num_groups, -1).contiguous()
                                else:
                                    # checkpoint shard is smaller: merge the shards from all checkpoint files
                                    assert tmp_data.dtype != torch.int8, \
                                        '''Merging checkpoints is not supported for INT8 checkpoints! \
                                           Please use as many GPUs as the TP-size of the checkpoint.'''
                                    all_data = [
                                        sd[j][prefix + n] if type(sd[j][prefix + n]) is list else sd[j][prefix + n].to(
                                            get_accelerator().current_device_name()) for j in range(len(sd))
                                    ]
                                    # Check if the weight tensor is for the QKV parameter
                                    if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
                                        qkv_size = src_shape[outer_dim] // 3
                                        src_split = [
                                            torch.split(src[0].data, qkv_size, dim=outer_dim) for src in all_data
                                        ]
                                        weight_partition = torch.cat([
                                            torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
                                            for i in range(len(src_split[0]))
                                        ], dim=dim)
                                    else:
                                        weight_partition = torch.cat([
                                            ad[0].to(get_accelerator().current_device_name())
                                            if type(ad) is list else ad for ad in all_data
                                        ], dim=dim)
                                    if tmp_data.dtype == torch.int8:
                                        scale = torch.cat(
                                            [ad[1].to(get_accelerator().current_device_name()) for ad in all_data],
                                            dim=dim)

                                if tmp_data.dtype != torch.int8:
                                    weight_partition = weight_quantizer.quantize(
                                        transpose(weight_partition), \
                                        parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
                                        weight_quantizer.quantize(weight_partition)
                                else:
                                    weight_partition = torch.nn.parameter.Parameter(weight_partition,
                                                                                    requires_grad=False)
                                    weight_partition.scale = scale
                                setattr(module, n, weight_partition)
                        else:
                            if src_shape[0] == dst_shape[0]:
                                p.data.copy_(tmp_data)
                            else:
                                if src_shape[0] > dst_shape[0]:
                                    # 1-D parameter (e.g. bias) larger than this rank's partition: keep our slice
                                    bias_split = torch.split(tmp_data, dst_shape[-1])[rank].to(
                                        get_accelerator().current_device_name()).contiguous()
                                    p.data.copy_(bias_split)
                                else:
                                    # Check if the weight tensor is for the QKV parameter
                                    if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
                                        qkv_size = src_shape[0] // 3
                                        src_split = [
                                            torch.split(sd[j][prefix + n], qkv_size, dim=0) for j in range(len(sd))
                                        ]
                                        p.data.copy_(
                                            torch.cat([
                                                torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
                                                for i in range(len(src_split[0]))
                                            ], dim=0).to(get_accelerator().current_device_name()).contiguous())
                                    else:
                                        p.data.copy_(
                                            torch.cat([sd[j][prefix + n] for j in range(len(sd))],
                                                      dim=0).to(get_accelerator().current_device_name()).contiguous())

            load_parameters(module, prefix)
            for n, child in module.named_children():
                load_parameters(child, prefix + n + '.')
        else:
            container.load_params(module, sd[0], weight_quantizer, mp_replace, prefix)
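
    # Optional dependencies: resolve classes from transformers / fairscale / llama when
    # available, otherwise fall back to None so the policy map below can still be built.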
    try:
        import transformers
        OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
        if hasattr(transformers.models, "llama"):
            LlamaRMSNorm = transformers.models.llama.modeling_llama.LlamaRMSNorm
        else:
            LlamaRMSNorm = None
    except:
        OPTLearnedPositionalEmbedding = None
        LlamaRMSNorm = None

    try:
        from fairscale.nn.model_parallel.layers import (
            ColumnParallelLinear,
            ParallelEmbedding,
            RowParallelLinear,
        )
    except:
        ColumnParallelLinear = None
        ParallelEmbedding = None
        RowParallelLinear = None

    try:
        from llama.model import RMSNorm
    except:
        RMSNorm = None

    layer_policies = {
        nn.Linear: load,
        nn.Embedding: load,
        nn.LayerNorm: load,
        EmbeddingLayer: load,
        LinearLayer: load,
        Normalize: load,
        transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
        DeepSpeedBloomInference: load_transformer_layer,
        DeepSpeedGPTInference: load_transformer_layer,
        DeepSpeedBERTInference: load_transformer_layer,
        DeepSpeedMegatronGPTInference: load_transformer_layer,
        DeepSpeedOPTInference: load_transformer_layer,
        DeepSpeedLlama2Inference: load_transformer_layer,
        OPTLearnedPositionalEmbedding: load,
        OPTEmbedding: load,
        LlamaRMSNorm: load,
        RMSNormalize: load,
        ColumnParallelLinear: load,
        ParallelEmbedding: load,
        RowParallelLinear: load,
        RMSNorm: load
    }
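
    # all_ds_ids maps a partitioned parameter's ds_id to the weight already materialized
    # for it, so parameters shared across modules (e.g. tied embeddings) are reused
    # rather than created twice.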
    all_ds_ids = {}

    def load_module_recursive(module, prefix='', level=0):
        for name, child in module.named_children():
            if child.__class__ in layer_policies:
                checking_key = prefix + name + '.'
                if not any(checking_key in item for item in sd[0].keys()):
                    if hasattr(child, 'weight') and \
                            (hasattr(child.weight, 'ds_id') and child.weight.ds_id in all_ds_ids):
                        prefix1 = all_ds_ids[child.weight.ds_id]
                        if child.__class__ is nn.Linear:
                            child = LinearLayer(weight=all_ds_ids[child.weight.ds_id])
                            setattr(module, name, child)
                    continue
                child_params = list(child.parameters())
                if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
                    if child.weight.is_meta:
                        ds_shape = child.weight.shape
                    else:
                        ds_shape = child.weight.ds_shape
                    if child.__class__ is nn.LayerNorm:
                        child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
                        setattr(module, name, child)
                    elif child.__class__ in [nn.Linear, ColumnParallelLinear, RowParallelLinear]:
                        child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
                        setattr(module, name, child)
                    elif child.__class__ is OPTLearnedPositionalEmbedding:
                        child = OPTEmbedding(weight_shape=ds_shape)
                        setattr(module, name, child)
                    elif child.__class__ in [LlamaRMSNorm, RMSNorm]:
                        child = RMSNormalize(dim=ds_shape[-1],
                                             dtype=child.weight.dtype,
                                             eps=child.eps if hasattr(child, 'eps') else child.variance_epsilon)
                        setattr(module, name, child)
                    else:
                        ds_id = None
                        if hasattr(child.weight, 'ds_id'):
                            ds_id = child.weight.ds_id
                        child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
                        if ds_id is not None:
                            all_ds_ids[ds_id] = child.weight
                        setattr(module, name, child)
                layer_policies[child.__class__](child, prefix + name + '.')
            else:
                load_module_recursive(
                    child,
                    prefix if (level == 0 and ckpt_type == 'pp') and skip_level_0_prefix else \
                    prefix + name + '.',
                    level + 1)

    load_module_recursive(r_module)
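
    # drop the loaded state dicts so their memory can be reclaimed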
    for sd_ in sd:
        del sd_
    sd = None
    gc.collect()