# auto_tp.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Automatic Tensor Parallelism
import re

from torch import nn
from .replace_policy import replace_policies
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


def move(tensor, device):
    if tensor.is_meta:
        return torch.empty_like(tensor, device=device)
    else:
        # Creating a new tensor helps free memory (e.g. after a split); this was previously done by calling clone().
        # Using copy=True instead of clone() also covers the cpu --> cpu case:
        # otherwise to() would return a view of the full tensor, and the full tensor would never be de-referenced.
        return tensor.to(device, copy=True)
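
# A minimal usage sketch (illustrative names, mirroring how move() is called further down in
# _replace): the shard selected for this rank is copied to the local accelerator device,
# after which the views of the full-size tensor can be released.
#
#   data = weight.split(shard_sizes, dim=0)
#   shard = move(data[rank], get_accelerator().current_device_name()).detach()
#   del data  # the split views no longer keep the full weight alive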


class ReplaceWithTensorSlicing:

    def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
        if mp_group is not None:
            self.gpu_index = dist.get_rank(group=mp_group)
        else:
            self.gpu_index = 0
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mp_size = mp_size

    def merge_assert(self, dim1, dim2):
        assert dim1 > dim2, \
            'Merging tensors is not allowed here! Please use deepspeed load_checkpoint\
            for merging your checkpoints before replacing the transformer layer with\
            inference-kernels'

    def strided_copy(self,
                     dst: Optional[torch.Tensor],
                     src: Optional[torch.Tensor],
                     num_splits: int,
                     int8: bool = False,
                     allocate_tensor: bool = False):
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape

        outer_dim = 0 if int8 else -1

        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                try:
                    dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                except:
                    print(dst.shape, src.shape)
                    exit()
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                weight_split[self.gpu_index].shape)
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
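
    # Worked example (illustrative numbers; actual dims depend on out_dim/outer_dim): for a
    # fused QKV weight whose fused dimension is the sharded one, with num_splits=3 and
    # mp_size=2, the source is first split into the Q, K and V blocks; each block is then
    # split per rank, and rank i receives the concatenation [Q_i, K_i, V_i], so every rank
    # keeps a consistent Q/K/V layout instead of a contiguous stripe of the fused tensor.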

    def copy(self, dst, src, int8=False, allocate_tensor=False):
        if src is None:
            return src
        assert not dst.data.is_meta  # the torch.Tensor.copy_ method used below will silently fail on meta tensors
        if allocate_tensor:
            dst = torch.empty_like(dst)
        outer_dim = 0 if int8 else 1
        inner_dim = 1 if int8 else 0
        src_shape = src.shape
        dst_shape = dst.shape
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[inner_dim] == dst_shape[self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
                dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
            else:
                if src_shape[inner_dim] != dst_shape[self.in_dim]:
                    self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim]] if inner_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.in_dim]:(self.gpu_index + 1) * dst_shape[self.in_dim], :])
                else:
                    self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
                    dst.data.copy_(src[:, self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim]] if outer_dim == 1 else \
                                   src[self.gpu_index * dst_shape[self.out_dim]:(self.gpu_index + 1) * dst_shape[self.out_dim], :])
        else:
            if src_shape[0] == dst_shape[0]:
                dst = src if src.dtype == dst.dtype else dst.data.copy_(src)
            else:
                dst.data.copy_(src[self.gpu_index * dst_shape[-1]:(self.gpu_index + 1) * dst_shape[-1]])

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
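
# For orientation (hypothetical shapes): with mp_size=2, out_dim=1 and in_dim=0, copying a
# full checkpoint weight of shape [4096, 4096] into a pre-sharded dst of shape [4096, 2048]
# selects columns [gpu_index * 2048 : (gpu_index + 1) * 2048] of src; if the input dimension
# is the pre-sharded one instead, the matching rows are selected.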


class Loading():

    def is_load_module(module):
        load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
        load_layer_names = [
            "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear"
        ]
        return module.__class__ in load_layers or module._get_name() in load_layer_names

    def load_buffer(module, state_dict, prefix):
        for name in module._buffers.keys():
            if module._buffers[name].data.is_meta:
                module._buffers[name] = torch.nn.parameter.Parameter(
                    data=torch.empty_like(module._buffers[name].data, device="cpu"),
                    requires_grad=module._buffers[name].data.requires_grad)
            if prefix + name in state_dict.keys():
                module._buffers[name].data.copy_(state_dict[prefix + name])

    def load(module, state_dict, prefix, mp_group=None):
        mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
        if hasattr(module, 'weight'):
            if module.weight.data.is_meta:
                # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                module.weight = torch.nn.parameter.Parameter(data=torch.empty_like(module.weight.data, device="cpu"),
                                                             requires_grad=module.weight.data.requires_grad)
            if 'query_key_value' in prefix:
                module.weight = mp_replace.strided_copy(module.weight.data,
                                                        state_dict[prefix + 'weight'],
                                                        num_splits=3)
            else:
                module.weight = mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
        else:
            if hasattr(module, 'norm') and hasattr(module.norm, 'weight'):
                if module.norm.weight.data.is_meta:
                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                    module.norm.weight = torch.nn.parameter.Parameter(
                        data=torch.empty_like(module.norm.weight.data, device="cpu"),
                        requires_grad=module.norm.weight.data.requires_grad)
                module.norm.weight = mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])

        if prefix + 'bias' in state_dict.keys():
            if hasattr(module, 'bias'):
                if module.bias.data.is_meta:
                    # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                    module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
                                                               requires_grad=module.bias.data.requires_grad)
                module.bias = mp_replace.copy(module.bias, state_dict[prefix + 'bias'])
            else:
                if hasattr(module, 'norm') and hasattr(module.norm, 'bias'):
                    if module.norm.bias.data.is_meta:
                        # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
                        module.norm.bias = torch.nn.parameter.Parameter(
                            data=torch.empty_like(module.norm.bias.data, device="cpu"),
                            requires_grad=module.norm.bias.data.requires_grad)
                    module.norm.bias = mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
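
# The Loading methods above take no `self` and are invoked through the class, e.g.
# Loading.load(child, state_dict, checking_key, mp_group) from _replace_module below:
# parameters whose prefix contains 'query_key_value' go through strided_copy (fused QKV),
# everything else through the plain sharded copy.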


class AutoTP():

    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
        self.module = module
        self.all_reduce_linears = all_reduce_linears
        self.prefix = prefix
        self.state_dict = state_dict
        self.mp_size = None
        self.mp_group = None
        self.linear_layer_setting = linear_layer_setting
        self.orig_layer_impl = orig_layer_impl
        self.linear_policies = None
        self.conv_linear_layer = False

    def in_module_list(module, module_list):
        for item in module_list:
            if type(item).__name__ == type(module).__name__:
                return True
        return False

    def get_module_list(model):
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    if not mlist:
                        mlist = [module]
                    elif not AutoTP.in_module_list(module, mlist):
                        mlist = mlist + [module]
            else:
                mlist = mlist + AutoTP.get_module_list(child)
        return mlist
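
    # For a typical Hugging Face decoder (assuming the transformer blocks live in an
    # nn.ModuleList), get_module_list returns one representative instance per distinct
    # layer class, e.g. a single decoder layer, rather than all N repeated blocks.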

    def supported(model):
        unsupported = ['deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        assert key is not None, "Not able to determine model policy automatically. Please provide policy."
        if key.group(1).lower() in unsupported:
            return False
        return True
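
    # The architecture name is recovered from the module's repr (assuming the usual Hugging
    # Face naming scheme): for example, a line such as "(bert): BertModel(" matches
    # ": (.*?)Model" and captures "Bert", which is lower-cased and checked against the
    # unsupported list.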

    def get_layers(parent, module):
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list = layer_list + [parent + "." + key]
            elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                layer_list = layer_list + ["ln"]
            else:
                layer_list = layer_list + AutoTP.get_layers(key, submodule)
        return layer_list

    def update_policy_list(policy_list, new_module, new_gems):
        if len(policy_list):
            for i, policy in enumerate(policy_list):
                # if module already exists in policy, combine gems and remove duplicates
                if policy[0] == type(new_module):
                    new_gems = set(new_gems + policy[1])
                    policy_list[i] = tuple([type(new_module), new_gems])
                    return policy_list
        policy_list.append(tuple([type(new_module), new_gems]))
        return policy_list

    def kernel_supported(module_list):
        policy = []
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                for orig_layer_class in plcy._orig_layer_class:
                    policy.append(orig_layer_class)
            elif plcy._orig_layer_class is not None:
                policy.append(plcy._orig_layer_class)

        for child in module_list:
            if child.__class__ in policy:
                return True
        return False

    def tp_parser(model):
        policy_list = []
        module_list = []
        layer_list = []
        gem_list = []

        module_list = AutoTP.get_module_list(model)
        assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
        for module in module_list:
            for key, submodule in module._modules.items():
                if isinstance(submodule, nn.Linear):
                    layer_list = layer_list + ["." + key]
                elif isinstance(submodule, nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
                    layer_list = layer_list + ["ln"]
                else:
                    layer_list = layer_list + AutoTP.get_layers(key, submodule)
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    if layer_list[i - 1] != 'ln':
                        gem_list = gem_list + [layer_list[i - 1]]
                elif 'out_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'o_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'down_proj' in layer:
                    gem_list = gem_list + [layer]
                elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
                    gem_list = gem_list + [layer]
                elif 'self_attention.dense' in layer and 'falcon' in str(
                        type(module)):  # this is a hack to get the right linear layer for this model!
                    gem_list = gem_list + [layer]
            layer_list = []
            if gem_list != []:
                gem_list = list(set(gem_list))
                policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
                gem_list = []
        assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
        return policy_list
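
    # The returned policy_list is a list of (layer_class, gems) tuples, where gems is the set
    # of linear-layer name fragments (e.g. '.out_proj', '.o_proj', '.down_proj') whose outputs
    # are partial sums under tensor parallelism and therefore need an all-reduce; these become
    # the all_reduce_linears passed to AutoTP, and every other linear is simply sliced.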

    def set_tensor_parallel_config(self, mp_size, mp_group):
        self.mp_size = mp_size
        self.mp_group = mp_group

    def _replace(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False) == True:
            return
        weight_shape = child.weight.shape
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
        if name in self.all_reduce_linears:
            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
            # else [weight_shape[0], weight_shape[1] // mp_size]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
            data = child.weight.data.split(get_shard_size_list(
                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
                                           dim=1)
            data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
            del data

            setattr(child, "replaced", True)
            if name == "lm_head" or name == 'embed_out':
                return LmHeadLinearAllreduce(
                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
                        move(child.bias,
                             get_accelerator().current_device_name())), self.mp_group)
            return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
        else:
            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
            # else [weight_shape[0] // mp_size, weight_shape[1]]
            if self.conv_linear_layer:
                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()

            if require_tp_fused_qkvw(name, self.mp_size):
                # for detecting the fused type
                module_str = str(self.module).strip()
                # The copy is a regular copy; the shapes of dst and src are the same.
                data_dc = move(
                    prepare_tp_fused_qkvw(module_str, child.weight.data, self.mp_size, mp_replace.gpu_index),
                    get_accelerator().current_device_name())
                bias_data_dc = None if child.bias is None else move(
                    prepare_tp_fused_qkvw(module_str, child.bias.data, self.mp_size, mp_replace.gpu_index),
                    get_accelerator().current_device_name())
            else:
                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
                                               dim=1 if self.conv_linear_layer else 0)
                data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
                del data

                if child.bias is not None:
                    bias_data = child.bias.data.split(get_shard_size_list(
                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
                                                      dim=0)
                    bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
                    bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                    del bias_data
                else:
                    bias_data_dc = None

            setattr(child, "replaced", True)
            return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
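
    # Sharding summary for _replace (as implemented above): linears listed in
    # all_reduce_linears (row-parallel, e.g. attention output and MLP down projections) are
    # sliced along the input dimension and wrapped in LinearAllreduce (or LmHeadLinearAllreduce
    # for lm_head/embed_out) so the partial products are summed across ranks; all other linears
    # (column-parallel, e.g. QKV and MLP up projections) are sliced along the output dimension
    # together with their bias and wrapped in LinearLayer.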

    def _slice_embedding(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False) == True:
            return
        mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

        if hasattr(child.weight, 'ds_tensor'):
            data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
        else:
            data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
        data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
        data = torch.nn.parameter.Parameter(data, requires_grad=False)

        new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size))
        new_embedding.weight.data.copy_(data)
        setattr(child, "replaced", True)
        return new_embedding
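
    # Embeddings are sharded along the embedding dimension (dim=1), not the vocabulary: with a
    # hypothetical [vocab=32000, hidden=4096] table and mp_size=4, each rank keeps a
    # [32000, 1024] slice, so token lookups stay local and only the hidden size is partitioned.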

    def update_mp_params(self, child):
        if getattr(child, "replaced", False) == True:
            return
        for param in [
                "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
                "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads",
                "d_model"
        ]:
            if hasattr(child, param):
                param_val = getattr(child, param)
                setattr(child, param, get_shard_size(param_val, self.mp_size))
        setattr(child, "replaced", True)
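
    # Illustrative effect: an attribute such as num_attention_heads=32 with mp_size=4 is
    # rewritten in place to get_shard_size(32, 4) (8 heads per rank in the even case), so the
    # module's internal reshape/view logic stays consistent with its sliced weights.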

    def update_linear_policies(self):
        self.conv_linear_layer = False
        if self.linear_layer_setting is not None:
            self.linear_policies = {self.linear_layer_setting[0]: self._replace}
            if len(self.linear_layer_setting) == 2:
                self.linear_policies.update({self.linear_layer_setting[1]: self._slice_embedding})
        else:
            import transformers
            if self.orig_layer_impl is transformers.models.gpt2.modeling_gpt2.GPT2Block:
                try:
                    self.conv_linear_layer = True
                    self.linear_policies = {transformers.pytorch_utils.Conv1D: self._replace}
                except ImportError:
                    self.linear_policies = {nn.Linear: self._replace}
            else:
                self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding}

    def _replace_module(self, r_module, prev_name='', prev_class_name=''):
        for name, child in r_module.named_children():
            if prev_class_name == "":
                class_name = prev_name
            elif prev_name == "":
                class_name = prev_class_name
            else:
                class_name = prev_class_name + '.' + prev_name
            checking_key = self.prefix + '.' + class_name + '.' + name + '.' if class_name != "" else self.prefix + '.' + name + '.'
            if Loading.is_load_module(child) and self.state_dict is not None:
                if any(checking_key in item for item in self.state_dict):
                    Loading.load(child, self.state_dict, checking_key, self.mp_group)
                else:
                    continue
            if len(child._buffers) != 0 and self.state_dict is not None:
                Loading.load_buffer(child, self.state_dict, checking_key)
            if child.__class__ in self.linear_policies:
                setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                              self.conv_linear_layer))
            elif any(isinstance(child, lp) for lp in self.linear_policies):
                # Added for falcon model support
                # Note: isinstance will account for class inheritance, child.__class__ does not
                key = None
                for lp in self.linear_policies:
                    if isinstance(child, lp):
                        key = lp
                        break

                assert key is not None
                setattr(r_module, name, self.linear_policies[key](child, prev_name + '.' + name,
                                                                  self.conv_linear_layer))
            else:
                self.update_mp_params(child)
                self._replace_module(child, name, class_name)
        return r_module

    def get_model_num_kv_heads(self, config):
        num_kv_heads = None
        kv_head_names = ['num_kv_heads', 'num_key_value_heads', 'num_attention_heads', 'n_heads']
        for name in kv_head_names:
            if hasattr(config, name):
                num_kv_heads = getattr(config, name)
                if num_kv_heads is not None:
                    break
        return num_kv_heads

    def _replace_last_linear_module(self, r_module):
        if hasattr(r_module, "lm_head"):
            name = "lm_head"
            child = r_module.lm_head
        elif hasattr(r_module, "embed_out"):
            name = "embed_out"
            child = r_module.embed_out
        else:
            return r_module
        if child.__class__ in self.linear_policies:
            setattr(r_module, name, self.linear_policies[child.__class__](child, name, self.conv_linear_layer))
        return r_module
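
# Rough usage sketch (an assumption about how the replace-module machinery drives this class
# when deepspeed.init_inference(...) is called without kernel injection; names below are
# illustrative, not a public API):
#
#   parser_dict = AutoTP.tp_parser(model)                 # [(layer_class, all_reduce_gems), ...]
#   _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, None, orig_layer_impl)
#   _autotp.set_tensor_parallel_config(mp_size, mp_group)
#   _autotp.update_linear_policies()
#   module = _autotp._replace_module(module)              # shards Linears/Embeddings in place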