auto_tp.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Automatic Tensor Parallelism
import re

from torch import nn
from .replace_policy import replace_policies
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw

class ReplaceWithTensorSlicing:

    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        if mp_group is not None:
            self.gpu_index = dist.get_rank(group=mp_group)
        else:
            self.gpu_index = 0
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mp_size = mp_size

    def merge_assert(self, dim1, dim2):
        assert dim1 > dim2, \
            'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\
        for merging your checkpoints before replacing the transformer layer with\
        inference-kernels'

    def strided_copy(self,
                     dst: Optional[torch.Tensor],
                     src: Optional[torch.Tensor],
                     num_splits: int,
                     int8: bool = False,
                     allocate_tensor: bool = False):
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape

        outer_dim = 0 if int8 else -1

        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                weight_split[self.gpu_index].shape)
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst

    def copy(self, dst, src, int8=False, allocate_tensor=False):
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocate_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            else:
                if src_shape[inner_dim] != dst_shape[self.in_dim]:
                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim], :])
                else:
                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim], :])
        else:
            if src_shape[0] == dst_shape[0]:
                dst = src if src.dtype == dst.dtype else dst.data.copy_(src)
            else:
                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
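
# Illustrative sketch (added comment, not part of the upstream file): how `copy`
# takes one rank's shard of a full 2-D weight. With mp_group=None the rank index
# defaults to 0, so the first row block of `src` is copied into `dst`:
#
#   mp_replace = ReplaceWithTensorSlicing(mp_group=None, mp_size=2)
#   src = torch.randn(8, 4)          # full weight, e.g. [out_features, in_features]
#   dst = torch.empty(4, 4)          # this rank's shard: out_features split by mp_size
#   dst = mp_replace.copy(dst, src)  # returns a Parameter holding rows 0..3 of src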

class Loading():

    def load_buffer(module, state_dict, prefix):
        for name in module._buffers.keys():
            if module._buffers[name].data.is_meta:
                module._buffers[name] = torch.nn.parameter.Parameter(
                    data=torch.empty_like(module._buffers[name].data, device="cpu"),
                    requires_grad=module._buffers[name].data.requires_grad)
            if prefix + name in state_dict.keys():
                module._buffers[name].data.copy_(state_dict[prefix + name])

    def load(module, state_dict, prefix, mp_group=None):
        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        if hasattr(module, 'weight'):
            if module.weight.data.is_meta:
                # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"),
                                                             requires_grad=module.weight.data.requires_grad)
            if 'query_key_value' in prefix:
                module.weight = mp_replace.strided_copy(module.weight.data,
                                                        state_dict[prefix + 'weight'],
                                                        num_splits=3)
            else:
                module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
        else:
            if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
                if module.norm.weight.data.is_meta:
                    # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                    module.norm.weight = torch.nn.parameter.Parameter(
                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
                        requires_grad=module.norm.weight.data.requires_grad)
                module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])

        if prefix + 'bias' in state_dict.keys():
            if hasattr(module, 'bias'):
                if module.bias.data.is_meta:
                    # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                    module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                               requires_grad=module.bias.data.requires_grad)
                module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
            else:
                if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
                    if module.norm.bias.data.is_meta:
                        # a meta tensor cannot be cast or copied to, so replace it with a normal tensor here
                        module.norm.bias = torch.nn.parameter.Parameter(
                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
                            requires_grad=module.norm.bias.data.requires_grad)
                    module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])

class AutoTP():

    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
        self.module = module
        self.all_reduce_linears = all_reduce_linears
        self.prefix = prefix
        self.state_dict = state_dict
        self.mp_size = None
        self.mp_group = None
        self.linear_layer_setting = linear_layer_setting
        self.orig_layer_impl = orig_layer_impl
        self.linear_policies = None
        self.conv_linear_layer = False

    def in_module_list(module, module_list):
        for item in module_list:
            if type(item).__name__ == type(module).__name__:
                return True
        return False

    def get_module_list(model):
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    if not mlist:
                        mlist = [module]
                    elif not AutoTP.in_module_list(module, mlist):
                        mlist = mlist + [module]
            else:
                mlist = mlist + AutoTP.get_module_list(child)
        return mlist

    def supported(model):
        unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        assert key is not None, "Not able to determine model policy automatically. Please provide policy."
        if key.group(1).lower() in unsupported:
            return False
        return True
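
    # Note (clarifying comment, not in the upstream file): `supported` keys off the
    # model's repr. For a Hugging Face model whose repr contains a line such as
    # "(transformer): BloomModel(", the regex captures "Bloom", which is lower-cased
    # and checked against the `unsupported` list above.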

    def get_layers(parent, module):
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list = layer_list + [parent + "." + key]
            elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                layer_list = layer_list + ["ln"]
            else:
                layer_list = layer_list + AutoTP.get_layers(key, submodule)
        return layer_list

    def update_policy_list(policy_list, new_module, new_gems):
        if len(policy_list):
            for i, policy in enumerate(policy_list):
                # if module already exists in policy, combine gems and remove duplicates
                if policy[0] == type(new_module):
                    new_gems = set(new_gems + policy[1])
                    policy_list[i] = tuple([type(new_module), new_gems])
                    return policy_list
        policy_list.append(tuple([type(new_module), new_gems]))
        return policy_list

    def kernel_supported(module_list):
        policy = []
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.append(orig_layer_class)
            elif plcy._orig_layer_class is not None:
                policy.append(plcy._orig_layer_class)
        for child in module_list:
            if child.__class__ in policy:
                return True
        return False

    def tp_parser(model):
        policy_list = []
        module_list = []
        layer_list = []
        gem_list = []

        module_list = AutoTP.get_module_list(model)
        assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
        for module in module_list:
            for key, submodule in module._modules.items():
                if isinstance(submodule, nn.Linear):
                    layer_list = layer_list + ["." + key]
                elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                    layer_list = layer_list + ["ln"]
                else:
                    layer_list = layer_list + AutoTP.get_layers(key, submodule)
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    if layer_list[i - 1] != 'ln':
                        gem_list = gem_list + [layer_list[i - 1]]
                elif 'out_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'o_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'down_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
                    gem_list = gem_list + [layer]
                elif 'self_attention.dense' in layer and 'falcon' in str(
                        type(module)):  # this is a hack to get the right linear layer for this model!
                    gem_list = gem_list + [layer]
            layer_list = []
            if gem_list != []:
                gem_list = list(set(gem_list))
                policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
                gem_list = []
        assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
        return policy_list
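
    # Illustrative sketch (added comment, not part of the upstream file): tp_parser
    # returns a list of (module_class, gems) pairs, where each gem names a linear
    # layer whose output should be all-reduced. For a LLaMA-style decoder layer this
    # would look roughly like:
    #
    #   policy_list = AutoTP.tp_parser(model)
    #   # e.g. [(<class '...LlamaDecoderLayer'>, ['self_attn.o_proj', 'mlp.down_proj'])]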

    def set_tensor_parallel_config(self, mp_size, mp_group):
        self.mp_size = mp_size
        self.mp_group = mp_group

    def _replace(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False) == True:
            return
        weight_shape = child.weight.shape
        if name == 'attn.Wqkv' and self.module._get_name() == 'MPTBlock':
            # MPT block qkv weights are laid out as [3, num_head, head_dim, hidden_size]
            # instead of [num_head, 3, head_dim, hidden_size], unlike other models
            new_weight = torch.empty((
                weight_shape[0] // self.mp_size,
                weight_shape[1],
            ),
                                     device=child.weight.device,
                                     dtype=child.weight.dtype)
            reversed_dim = True
            mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, out_dim=0)
            # todo: can we remove new tensor allocation if we use strided copy?
            mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
            setattr(child, "replaced", True)
            return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
        if name in self.all_reduce_linears:
            # if conv_linear_layer, the sharded weight shape is [weight_shape[1], weight_shape[0] // mp_size]
            # else it is [weight_shape[0], weight_shape[1] // mp_size]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
            data = child.weight.data.split(
                (weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
            data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

            setattr(child, "replaced", True)
            return LinearAllreduce(data, child.bias if child.bias is None else \
                torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
        else:
            # if conv_linear_layer, the sharded weight shape is [weight_shape[1], weight_shape[0] // mp_size]
            # else it is [weight_shape[0] // mp_size, weight_shape[1]]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
            if require_tp_fused_qkvw(name, self.mp_size):
                # for detecting the fused qkv type
                module_str = str(self.module).strip()
                # this is a regular copy: dst and src have the same shape
                data = prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index)
                bias_data = None if child.bias is None else prepare_tp_fused_qkvw(
                    module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
                        get_accelerator().current_device_name())
            else:
                data = child.weight.data.split((weight_shape[0]) // self.mp_size,
                                               dim=1 if self.conv_linear_layer else 0)
                data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

                if child.bias is not None:
                    bias_data = child.bias.data.split(
                        (weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
                    bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
                else:
                    bias_data = None

            setattr(child, "replaced", True)
            return LinearLayer(weight=data.to(get_accelerator().current_device_name()), bias=bias_data)

    def _slice_embedding(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False) == True:
            return
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

        if hasattr(child.weight, 'ds_tensor'):
            data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
        else:
            data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
        data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

        new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
        new_embedding.weight.data.copy_(data)
        setattr(child, "replaced", True)
        return new_embedding

    def update_mp_params(self, child):
        if getattr(child, "replaced", False) == True:
            return
        for param in [
                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
                "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads"
        ]:
            if hasattr(child, param):
                param_val = getattr(child, param)
                assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
                setattr(child, param, param_val // self.mp_size)

        setattr(child, "replaced", True)
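
    # For example (clarifying comment, not in the upstream file): with mp_size=4, a
    # child carrying num_attention_heads=32 and hidden_size=4096 is rewritten in place
    # to num_attention_heads=8 and hidden_size=1024, so the module's own bookkeeping
    # matches its sharded weights.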

    def update_linear_policies(self):
        self.conv_linear_layer = False
        if self.linear_layer_setting is not None:
            self.linear_policies = {self.linear_layer_setting[0]: self._replace}
            if len(self.linear_layer_setting) == 2:
                self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding})
        else:
            import transformers
            if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block:
                try:
                    self.conv_linear_layer = True
                    self.linear_policies = {transformers.pytorch_utils.Conv1D: self._replace}
                except ImportError:
                    self.linear_policies = {nn.Linear: self._replace}
            else:
                self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding}

    def _replace_module(self, r_module, prev_name='', prev_class_name=''):
        for name, child in r_module.named_children():
            if prev_class_name == "":
                class_name = prev_name
            elif prev_name == "":
                class_name = prev_class_name
            else:
                class_name = prev_class_name + '.' + prev_name
            checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
            if (child.__class__ in [nn.Linear, nn.Embedding, nn.LayerNorm]
                    or child._get_name() in ["LPLayerNorm", "SharedEmbedding"]) and self.state_dict is not None:
                if any(checking_key in item for item in self.state_dict):
                    Loading.load(child, self.state_dict, checking_key, self.mp_group)
                else:
                    continue
            if len(child._buffers) != 0 and self.state_dict is not None:
                Loading.load_buffer(child, self.state_dict, checking_key)
            if child.__class__ in self.linear_policies:
                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                              self.conv_linear_layer))
            elif any(isinstance(child, lp) for lp in self.linear_policies):
                # Added for falcon model support
                # Note: isinstance will account for class inheritance, child.__class__ does not
                key = None
                for lp in self.linear_policies:
                    if isinstance(child, lp):
                        key = lp
                        break
                assert key is not None
                setattr(r_module, name, self.linear_policies[key](child, prev_name + '.' + name,
                                                                  self.conv_linear_layer))
            else:
                self.update_mp_params(child)
                self._replace_module(child, name, class_name)
        return r_module
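
# Illustrative end-to-end sketch (an assumption about how this class is driven by
# DeepSpeed's module-replacement path, kept as a comment so importing this file has
# no side effects):
#
#   auto_tp = AutoTP(module, all_reduce_linears, prefix, state_dict,
#                    linear_layer_setting=None, orig_layer_impl=orig_layer_impl)
#   auto_tp.set_tensor_parallel_config(mp_size, mp_group)
#   auto_tp.update_linear_policies()
#   module = auto_tp._replace_module(module)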