123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- '''Copyright The Microsoft DeepSpeed Team'''
- # Automatic Tensor Parallelism
- import re
- from torch import nn
- from .replace_policy import replace_policies
class AutoTP:
    """Automatic tensor-parallelism policy discovery.

    All helpers are stateless and exposed as static methods, so they can be
    called either on the class (``AutoTP.tp_parser(model)``) or on an instance.
    The entry point is :meth:`tp_parser`, which returns a list of
    ``(module_class, layer_names)`` tuples.
    """

    @staticmethod
    def in_module_list(module, module_list):
        """Return True if ``module_list`` holds an item with the same class name as ``module``.

        The comparison is on ``type(...).__name__`` only, so two distinct
        classes that share a name are treated as the same.
        """
        target_name = type(module).__name__
        return any(type(item).__name__ == target_name for item in module_list)

    @staticmethod
    def get_module_list(model):
        """Collect one representative module per class name from every ``nn.ModuleList``.

        The children of ``model`` are walked recursively.  For each
        ``nn.ModuleList`` encountered, its direct children are added unless a
        module with the same class name is already in the list being built at
        this recursion level.  Lists returned from recursive calls are
        concatenated without further de-duplication.

        Args:
            model: an ``nn.Module`` (anything exposing ``children()``).

        Returns:
            list of module instances.
        """
        mlist = []
        for child in model.children():
            if isinstance(child, nn.ModuleList):
                for module in child.children():
                    if not AutoTP.in_module_list(module, mlist):
                        mlist.append(module)
            else:
                mlist.extend(AutoTP.get_module_list(child))
        return mlist

    @staticmethod
    def supported(model):
        """Return False when the model family parsed from ``str(model)`` is on the unsupported list.

        The family name is the text captured by the first of these patterns
        that matches ``str(model)``: ``": (.*?)Model"``, ``": (.*?)Stack"``
        (both searched anywhere), then ``"(.*?)Model"`` anchored at the start.
        The captured text is lower-cased before the lookup.

        Raises:
            AssertionError: if none of the patterns matches.
        """
        unsupported = (
            'bloom',
            'codegen',
            'deberta',
            'flaubert',
            'fsmt',
            'gpt2',
            'led',
            'longformer',
            'xlm',
            'xlnet',
        )
        model = str(model)
        key = re.search(r": (.*?)Model", model)
        if key is None:
            key = re.search(r": (.*?)Stack", model)
        if key is None:
            key = re.match(r"(.*?)Model", model)
        assert key is not None, "Not able to determine model policy automatically. Please provide policy."
        return key.group(1).lower() not in unsupported

    @staticmethod
    def get_layers(parent, module):
        """Flatten ``module`` into an ordered list of layer markers.

        For each entry of ``module._modules``:

        * an ``nn.Linear`` contributes ``parent + "." + key``;
        * an ``nn.LayerNorm``, or any entry whose key is ``'LayerNorm'`` or
          ``'layer_norm'``, contributes the marker ``"ln"``;
        * anything else is recursed into, with its own key as the new
          ``parent`` (so names carry only the immediate parent, not the full
          dotted path).

        Args:
            parent: name prefix used for linear layers found directly in ``module``.
            module: the module whose ``_modules`` mapping is inspected.

        Returns:
            list of str, in registration order.
        """
        layer_list = []
        for key, submodule in module._modules.items():
            if isinstance(submodule, nn.Linear):
                layer_list.append(parent + "." + key)
            elif isinstance(submodule, nn.LayerNorm) or key in ('LayerNorm', 'layer_norm'):
                layer_list.append("ln")
            elif submodule is not None:
                # ``_modules`` may hold None placeholders; there is nothing to
                # descend into, and recursing would raise AttributeError.
                layer_list.extend(AutoTP.get_layers(key, submodule))
        return layer_list

    @staticmethod
    def update_policy_list(policy_list, new_module, new_gems):
        """Insert or merge the policy entry for ``type(new_module)``.

        If an entry for the same class already exists, its layer names are
        replaced by the union (as a ``set``) of the old and new names.
        Otherwise ``(type(new_module), new_gems)`` is appended with
        ``new_gems`` stored as given.

        Args:
            policy_list: list of ``(class, names)`` tuples; modified in place.
            new_module: module instance whose class keys the entry.
            new_gems: iterable of layer names for that class.

        Returns:
            The same ``policy_list`` object.
        """
        module_type = type(new_module)
        for i, policy in enumerate(policy_list):
            # if module already exists in policy, combine gems and remove duplicates
            if policy[0] == module_type:
                # A previously merged entry already stores a set, so build the
                # union from two sets rather than concatenating list + set.
                policy_list[i] = (module_type, set(new_gems) | set(policy[1]))
                return policy_list
        policy_list.append((module_type, new_gems))
        return policy_list

    @staticmethod
    def kernel_supported(module_list):
        """Return True if any module's class is an ``_orig_layer_class`` of a known replace policy."""
        policy = []
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                policy.extend(plcy._orig_layer_class)
            elif plcy._orig_layer_class is not None:
                policy.append(plcy._orig_layer_class)
        return any(child.__class__ in policy for child in module_list)

    @staticmethod
    def tp_parser(model):
        """Build the tensor-parallel policy list for ``model``.

        For every module returned by :meth:`get_module_list`, the flattened
        layer list from :meth:`get_layers` is scanned and a layer name is
        selected when it either immediately precedes an ``"ln"`` marker or
        contains ``'out_proj'``.  Selected names are de-duplicated (keeping
        first-seen order) and recorded via :meth:`update_policy_list`.

        Args:
            model: the ``nn.Module`` to analyse.

        Returns:
            list of ``(module_class, layer_names)`` tuples.

        Raises:
            AssertionError: if the model family is unsupported or no policy
                entry could be derived.
        """
        policy_list = []
        module_list = AutoTP.get_module_list(model)
        assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "AutoTP not supported for model. Please provide policy."
        for module in module_list:
            # An empty parent yields ".<key>" for linears directly under the module.
            layer_list = AutoTP.get_layers("", module)
            gem_list = []
            for i, layer in enumerate(layer_list):
                if layer == 'ln':
                    # NOTE: for i == 0 the index i - 1 is -1, i.e. the LAST
                    # entry of layer_list. This wrap-around is kept as-is
                    # because existing results depend on it.
                    if layer_list[i - 1] != 'ln':
                        gem_list.append(layer_list[i - 1])
                elif 'out_proj' in layer:
                    gem_list.append(layer)
            if gem_list:
                # Order-preserving de-duplication keeps the result deterministic.
                gem_list = list(dict.fromkeys(gem_list))
                policy_list = AutoTP.update_policy_list(policy_list, module, gem_list)
        assert len(policy_list), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
            if AutoTP.kernel_supported(module_list) else "Not able to determine model policy automatically. Please provide policy."
        return policy_list
|