# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Automatic Tensor Parallelism
import re

from torch import nn
from .replace_policy import replace_policies
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
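

# ReplaceWithTensorSlicing copies tensors from a full (unsharded) checkpoint into the
# shard owned by this rank. `gpu_index` is this rank's position within the tensor-parallel
# (model-parallel) group; `in_dim`/`out_dim` name the input and output feature dimensions
# of the destination weight, which differ between Linear and Conv1D-style layouts.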
class ReplaceWithTensorSlicing:

    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        if mp_group is not None:
            self.gpu_index = dist.get_rank(group=mp_group)
        else:
            self.gpu_index = 0
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mp_size = mp_size

    def merge_assert(self, dim1, dim2):
        assert dim1 > dim2, \
            'Merging tensors is not allowed here! Please use deepspeed load_checkpoint ' \
            'for merging your checkpoints before replacing the transformer layer with ' \
            'inference-kernels'
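
    # strided_copy handles fused weights such as query_key_value: `src` holds `num_splits`
    # logically distinct tensors (e.g. Q, K, V) concatenated along the output dimension.
    # Each of those chunks is split again across ranks, and the pieces belonging to this
    # rank are re-concatenated so every rank ends up with matching Q/K/V slices.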
    def strided_copy(self,
                     dst: Optional[torch.Tensor],
                     src: Optional[torch.Tensor],
                     num_splits: int,
                     int8: bool = False,
                     allocate_tensor: bool = False):
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape

        outer_dim = 0 if int8 else -1

        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                try:
                    dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                except RuntimeError as e:
                    # raise a descriptive error instead of printing the shapes and exiting
                    raise RuntimeError(f"strided_copy failed: cannot copy src {tuple(src.shape)} "
                                       f"into dst {tuple(dst.shape)}") from e
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                weight_split[self.gpu_index].shape)
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
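
    # copy shards a single (non-fused) tensor: if src already matches dst, it is copied
    # whole; otherwise the slice belonging to this rank (by gpu_index) is copied along
    # whichever dimension differs. With int8=True the roles of the two dimensions are
    # swapped.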
    def copy(self, dst, src, int8=False, allocate_tensor=False):
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocate_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            else:
                if src_shape[inner_dim] != dst_shape[self.in_dim]:
                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim], :])
                else:
                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim], :])
        else:
            if src_shape[0] == dst_shape[0]:
                dst = src if src.dtype == dst.dtype else dst.data.copy_(src)
            else:
                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
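

# Loading materializes meta-tensor parameters and buffers as normal CPU tensors and fills
# them from a state_dict, slicing with ReplaceWithTensorSlicing where the checkpoint
# tensor is larger than the shard this rank holds.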
class Loading:

    @staticmethod
    def is_load_module(module):
        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
        load_layer_names = ["LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm"]
        return module.__class__ in load_layers or module._get_name() in load_layer_names

    @staticmethod
    def load_buffer(module, state_dict, prefix):
        for name in module._buffers.keys():
            if module._buffers[name].data.is_meta:
                module._buffers[name] = torch.nn.parameter.Parameter(
                    data=torch.empty_like(module._buffers[name].data, device="cpu"),
                    requires_grad=module._buffers[name].data.requires_grad)
            if prefix + name in state_dict.keys():
                module._buffers[name].data.copy_(state_dict[prefix + name])

    @staticmethod
    def load(module, state_dict, prefix, mp_group=None):
        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        if hasattr(module, 'weight'):
            if module.weight.data.is_meta:
                # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"),
                                                             requires_grad=module.weight.data.requires_grad)
            if 'query_key_value' in prefix:
                module.weight = mp_replace.strided_copy(module.weight.data,
                                                        state_dict[prefix + 'weight'],
                                                        num_splits=3)
            else:
                module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
        else:
            if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
                if module.norm.weight.data.is_meta:
                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                    module.norm.weight = torch.nn.parameter.Parameter(
                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
                        requires_grad=module.norm.weight.data.requires_grad)
                module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])

        if prefix + 'bias' in state_dict.keys():
            if hasattr(module, 'bias'):
                if module.bias.data.is_meta:
                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                    module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                               requires_grad=module.bias.data.requires_grad)
                module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
            else:
                if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
                    if module.norm.bias.data.is_meta:
                        # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                        module.norm.bias = torch.nn.parameter.Parameter(
                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
                            requires_grad=module.norm.bias.data.requires_grad)
                    module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
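

# AutoTP discovers, for an arbitrary supported HuggingFace-style model, which Linear
# layers must become LinearAllreduce (row-parallel layers whose partial outputs are
# summed across ranks) and which become sharded LinearLayer instances, then rewrites the
# module tree in place.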
class AutoTP:

    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
        self.module = module
        self.all_reduce_linears = all_reduce_linears
        self.prefix = prefix
        self.state_dict = state_dict
        self.mp_size = None
        self.mp_group = None
        self.linear_layer_setting = linear_layer_setting
        self.orig_layer_impl = orig_layer_impl
        self.linear_policies = None
        self.conv_linear_layer = False

    @staticmethod
    def in_module_list(module, module_list):
        for item in module_list:
            if type(item).__name__ == type(module).__name__:
                return True
        return False

    @staticmethod
    def get_module_list(model):
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    if not mlist:
                        mlist = [module]
                    elif not AutoTP.in_module_list(module, mlist):
                        mlist = mlist + [module]
            else:
                mlist = mlist + AutoTP.get_module_list(child)
        return mlist

    @staticmethod
    def supported(model):
        unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        assert key is not None, "Not able to determine model policy automatically. Please provide policy."
        if key.group(1).lower() in unsupported:
            return False
        return True

    @staticmethod
    def get_layers(parent, module):
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list = layer_list + [parent + "." + key]
            elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                layer_list = layer_list + ["ln"]
            else:
                layer_list = layer_list + AutoTP.get_layers(key, submodule)
        return layer_list
    @staticmethod
    def update_policy_list(policy_list, new_module, new_gems):
        if len(policy_list):
            for i, policy in enumerate(policy_list):
                # if module already exists in policy, combine gems and remove duplicates
                if policy[0] == type(new_module):
                    # cast both sides to list first: after a previous merge the stored
                    # gems are a set, and `list + set` would raise a TypeError
                    new_gems = set(list(new_gems) + list(policy[1]))
                    policy_list[i] = tuple([type(new_module), new_gems])
                    return policy_list
        policy_list.append(tuple([type(new_module), new_gems]))
        return policy_list
    @staticmethod
    def kernel_supported(module_list):
        policy = []
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.append(orig_layer_class)
            elif plcy._orig_layer_class is not None:
                policy.append(plcy._orig_layer_class)
        for child in module_list:
            if child.__class__ in policy:
                return True
        return False
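
    # tp_parser scans each unique transformer block and records which Linear layers
    # produce partial sums under tensor parallelism and therefore need an all-reduce:
    # a Linear immediately preceding a LayerNorm, plus attention/MLP output projections
    # matched by name (out_proj, o_proj, down_proj, GPT-NeoX's attention.dense and
    # Falcon's self_attention.dense). Returns a list of (module type, layer names) pairs.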
    @staticmethod
    def tp_parser(model):
        policy_list = []
        module_list = []
        layer_list = []
        gem_list = []

        module_list = AutoTP.get_module_list(model)
        assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
        for module in module_list:
            for key, submodule in module._modules.items():
                if isinstance(submodule, nn.Linear):
                    layer_list = layer_list + ["." + key]
                elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                    layer_list = layer_list + ["ln"]
                else:
                    layer_list = layer_list + AutoTP.get_layers(key, submodule)
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    if layer_list[i - 1] != 'ln':
                        gem_list = gem_list + [layer_list[i - 1]]
                elif 'out_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'o_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'down_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
                    gem_list = gem_list + [layer]
                elif 'self_attention.dense' in layer and 'falcon' in str(
                        type(module)):  # this is a hack to get the right linear layer for this model!
                    gem_list = gem_list + [layer]
            layer_list = []
            if gem_list != []:
                gem_list = list(set(gem_list))
                policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
                gem_list = []
        assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
        return policy_list
    def set_tensor_parallel_config(self, mp_size, mp_group):
        self.mp_size = mp_size
        self.mp_group = mp_group
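
    # _replace swaps one nn.Linear (or Conv1D) for its tensor-parallel counterpart.
    # Layers named in all_reduce_linears are sharded along the input dimension and wrapped
    # in LinearAllreduce (each rank computes a partial product that is summed across the
    # group); all other linears are sharded along the output dimension into a LinearLayer,
    # with fused QKV weights handled separately via prepare_tp_fused_qkvw.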
    def _replace(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False):
            return
        weight_shape = child.weight.shape
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
        if name in self.all_reduce_linears:
            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
            # else [weight_shape[0], weight_shape[1] // mp_size]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
            data = child.weight.data.split(
                (weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
            data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

            setattr(child, "replaced", True)
            return LinearAllreduce(torch.nn.parameter.Parameter(data, requires_grad=False), child.bias if child.bias is None else \
                                   torch.nn.parameter.Parameter(child.bias.to(get_accelerator().current_device_name())), self.mp_group)
        else:
            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
            # else [weight_shape[0] // mp_size, weight_shape[1]]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

            if require_tp_fused_qkvw(name, self.mp_size):
                # detect the fused type from the module string
                module_str = str(self.module).strip()
                # this copy is a regular copy; the shapes of dst and src are the same
                data = prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index)

                bias_data = None if child.bias is None else prepare_tp_fused_qkvw(
                    module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
                        get_accelerator().current_device_name())
            else:
                data = child.weight.data.split((weight_shape[0]) // self.mp_size,
                                               dim=1 if self.conv_linear_layer else 0)
                data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

                if child.bias is not None:
                    bias_data = child.bias.data.split(
                        (weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
                    bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
                    bias_data = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                else:
                    bias_data = None

            setattr(child, "replaced", True)
            return LinearLayer(weight=torch.nn.parameter.Parameter(data.to(get_accelerator().current_device_name()), requires_grad=False), \
                               bias=bias_data)
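
    # _slice_embedding shards an nn.Embedding along the embedding (hidden) dimension:
    # the vocabulary dimension stays intact, and each rank keeps hidden_size // mp_size
    # columns of the weight.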
    def _slice_embedding(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False):
            return
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

        if hasattr(child.weight, 'ds_tensor'):
            data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
        else:
            data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
        data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
        data = torch.nn.parameter.Parameter(data, requires_grad=False)

        new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
        new_embedding.weight.data.copy_(data)
        setattr(child, "replaced", True)
        return new_embedding
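
    # update_mp_params rescales size-related attributes (head counts, hidden sizes, etc.)
    # on modules that keep their own copies of these values, so that the attention math
    # inside each block matches the sharded weights.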
    def update_mp_params(self, child):
        if getattr(child, "replaced", False):
            return
        for param in [
                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
                "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads"
        ]:
            if hasattr(child, param):
                param_val = getattr(child, param)
                assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
                setattr(child, param, param_val // self.mp_size)
        setattr(child, "replaced", True)
    def update_linear_policies(self):
        self.conv_linear_layer = False
        if self.linear_layer_setting is not None:
            self.linear_policies = {self.linear_layer_setting[0]: self._replace}
            if len(self.linear_layer_setting) == 2:
                self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding})
        else:
            import transformers
            if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block:
                try:
                    self.conv_linear_layer = True
                    self.linear_policies = {transformers.pytorch_utils.Conv1D: self._replace}
                except ImportError:
                    self.linear_policies = {nn.Linear: self._replace}
            else:
                self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding}
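
    # _replace_module walks the module tree recursively: leaf modules that Loading knows
    # how to load are filled from the state_dict, linears matching a policy are swapped
    # via _replace/_slice_embedding, and everything else has its parallelism-sensitive
    # attributes rescaled before recursing into its children.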
    def _replace_module(self, r_module, prev_name='', prev_class_name=''):
        for name, child in r_module.named_children():
            if prev_class_name == "":
                class_name = prev_name
            elif prev_name == "":
                class_name = prev_class_name
            else:
                class_name = prev_class_name + '.' + prev_name
            checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
            if Loading.is_load_module(child) and self.state_dict is not None:
                if any(checking_key in item for item in self.state_dict):
                    Loading.load(child, self.state_dict, checking_key, self.mp_group)
                else:
                    continue
            if len(child._buffers) != 0 and self.state_dict is not None:
                Loading.load_buffer(child, self.state_dict, checking_key)
            if child.__class__ in self.linear_policies:
                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                              self.conv_linear_layer))
            elif any(isinstance(child, lp) for lp in self.linear_policies):
                # Added for falcon model support
                # Note: isinstance will account for class inheritance, child.__class__ does not
                key = None
                for lp in self.linear_policies:
                    if isinstance(child, lp):
                        key = lp
                        break
                assert key is not None
                setattr(r_module, name, self.linear_policies[key](child, prev_name + '.' + name,
                                                                  self.conv_linear_layer))
            else:
                self.update_mp_params(child)
                self._replace_module(child, name, class_name)
        return r_module
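

# Illustrative sketch only (the real driver lives in DeepSpeed's module-replacement path,
# not in this file): given a model, a checkpoint state_dict/prefix, and an assumed
# tensor-parallel size and group, the pieces above compose roughly as follows.
#
#   parser_dict = AutoTP.tp_parser(model)   # [(block type, all-reduce layer names), ...]
#   for block_type, all_reduce_linears in parser_dict:
#       autotp = AutoTP(model, all_reduce_linears, prefix, state_dict,
#                       linear_layer_setting=None, orig_layer_impl=None)
#       autotp.set_tensor_parallel_config(mp_size, mp_group)  # mp_size/mp_group assumed given
#       autotp.update_linear_policies()
#       model = autotp._replace_module(model)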