auto_tp.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. '''Copyright The Microsoft DeepSpeed Team'''
  2. # Automatic Tensor Parallelism
  3. import re
  4. from torch import nn
  5. from .replace_policy import replace_policies
  6. class AutoTP():
  7. def in_module_list(module, module_list):
  8. for item in module_list:
  9. if type(item).__name__ == type(module).__name__:
  10. return True
  11. return False
  12. def get_module_list(model):
  13. mlist = []
  14. for child in model.children():
  15. if isinstance(child, nn.ModuleList):
  16. for module in child.children():
  17. if not mlist:
  18. mlist = [module]
  19. elif not AutoTP.in_module_list(module, mlist):
  20. mlist = mlist + [module]
  21. else:
  22. mlist = mlist + AutoTP.get_module_list(child)
  23. return mlist
  24. def supported(model):
  25. unsupported = [
  26. 'bloom',
  27. 'codegen',
  28. 'deberta',
  29. 'flaubert',
  30. 'fsmt',
  31. 'gpt2',
  32. 'led',
  33. 'longformer',
  34. 'xlm',
  35. 'xlnet'
  36. ]
  37. model = str(model)
  38. key = re.search(r": (.*?)Model", model)
  39. if key is None:
  40. key = re.search(r": (.*?)Stack", model)
  41. if key is None:
  42. key = re.match(r"(.*?)Model", model)
  43. assert key is not None, "Not able to determine model policy automatically. Please provide policy."
  44. if key.group(1).lower() in unsupported:
  45. return False
  46. return True
  47. def get_layers(parent, module):
  48. layer_list = []
  49. for key, submodule in module._modules.items():
  50. if isinstance(submodule, nn.Linear):
  51. layer_list = layer_list + [parent + "." + key]
  52. elif isinstance(submodule,
  53. nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
  54. layer_list = layer_list + ["ln"]
  55. else:
  56. layer_list = layer_list + AutoTP.get_layers(key, submodule)
  57. return layer_list
  58. def update_policy_list(policy_list, new_module, new_gems):
  59. if len(policy_list):
  60. for i, policy in enumerate(policy_list):
  61. # if module already exists in policy, combine gems and remove duplicates
  62. if policy[0] == type(new_module):
  63. new_gems = set(new_gems + policy[1])
  64. policy_list[i] = tuple([type(new_module), new_gems])
  65. return policy_list
  66. policy_list.append(tuple([type(new_module), new_gems]))
  67. return policy_list
  68. def kernel_supported(module_list):
  69. policy = []
  70. for plcy in replace_policies:
  71. # instantiate a throw-away policy in order to populate the _orig_layer_class
  72. _ = plcy(None)
  73. if isinstance(plcy._orig_layer_class, list):
  74. for orig_layer_class in plcy._orig_layer_class:
  75. policy.append(orig_layer_class)
  76. elif plcy._orig_layer_class is not None:
  77. policy.append(plcy._orig_layer_class)
  78. for child in module_list:
  79. if child.__class__ in policy:
  80. return True
  81. return False
  82. def tp_parser(model):
  83. policy_list = []
  84. module_list = []
  85. layer_list = []
  86. gem_list = []
  87. module_list = AutoTP.get_module_list(model)
  88. assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
  89. if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
  90. for module in module_list:
  91. for key, submodule in module._modules.items():
  92. if isinstance(submodule, nn.Linear):
  93. layer_list = layer_list + ["." + key]
  94. elif isinstance(
  95. submodule,
  96. nn.LayerNorm) or key == 'LayerNorm' or key == 'layer_norm':
  97. layer_list = layer_list + ["ln"]
  98. else:
  99. layer_list = layer_list + AutoTP.get_layers(key, submodule)
  100. for i, layer in enumerate(layer_list):
  101. if layer == 'ln':
  102. if layer_list[i - 1] != 'ln':
  103. gem_list = gem_list + [layer_list[i - 1]]
  104. elif 'out_proj' in layer:
  105. gem_list = gem_list + [layer]
  106. layer_list = []
  107. if gem_list != []:
  108. gem_list = list(set(gem_list))
  109. policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
  110. gem_list = []
  111. assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
  112. if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
  113. return policy_list