engine.py

'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
import os

import torch
import torch.distributed as dist
from torch.nn.modules import Module

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject.replace_module import replace_transformer_layer
from ..utils import logger, init_distributed
from ..pipe import PipelineModule


class InferenceEngine(Module):
    inference_mp_group = None

    def __init__(self,
                 model,
                 mp_size=1,
                 mpu=None,
                 checkpoint=None,
                 dtype=None,
                 injection_dict=None,
                 return_tuple=True,
                 replace_method='auto',
                 quantization_setting=None,
                 replace_with_kernel_inject=False):
        """
        Args:
            model: torch.nn.Module
            mp_size: model-parallel size
            mpu: model-parallel unit (used for Megatron-type models)
            checkpoint: path to a checkpoint JSON file that lists the model checkpoint files.
                Example: {"type": "Megatron", "checkpoints": ["ckpt_mp0.pt", "ckpt_mp1.pt"], "version": 1.0}
            dtype: data type in which inference is executed
            injection_dict: dictionary mapping client modules to their injection policies.
                Example: {BertLayer: HFBertLayerPolicy}
            return_tuple: if True, the inference API returns a tuple; otherwise it returns a tensor
            replace_method: the injection method; pass 'auto' when no injection policy is
                defined, in which case layers are injected automatically based on the
                available policies
            quantization_setting: one of None, a (mlp_extra_grouping, quantize_groups)
                tuple, or quantize_groups alone
        """
        super().__init__()

        self.module = model

        self._get_model_config_generate()

        self.mp_world_size = mp_size
        self.checkpoint = checkpoint
        self.dtype = dtype
        self.injection_dict = injection_dict
        self.mp_group = None
        self.mpu = mpu
        self._validate_args(mpu)
        self.replace_method = replace_method
        self.quantize_merge_count = 1
        self.quantization_scales = None

        self._init_quantization_setting(quantization_setting)

        if self.checkpoint:
            self._load_checkpoint(self.checkpoint)

        # convert model to intended dtype
        if self.dtype:
            self._convert_to_dtype()

        if self.mpu:
            self.mp_world_size = dist.get_world_size(
                group=self.mpu.get_model_parallel_group())
            self.mp_group = self.mpu.get_model_parallel_group()
        elif self.mp_world_size > 1:
            self._create_model_parallel_group()

        # apply injection policy
        if self.injection_dict:
            for client_module, injection_policy in self.injection_dict.items():
                self._apply_injection_policy(client_module,
                                             injection_policy,
                                             return_tuple,
                                             replace_with_kernel_inject)
        elif replace_method == 'auto':
            self._apply_injection_policy(
                return_tuple=return_tuple,
                replace_with_kernel_inject=replace_with_kernel_inject)

        device = torch.cuda.current_device()
        logger.info(f"Place model to device: {device}")
        self.module.to(device)

        if self.mp_world_size > 1:
            self.model_orig_fwd = self.module.forward
            self.module.forward = self.forward
        else:
            self.module.register_forward_pre_hook(self._pre_forward_hook)
    def _get_model_config_generate(self):
        self.config = getattr(self.module, 'config', None)
        self.generate = getattr(self.module, 'generate', None)

    def _create_model_parallel_group(self):
        # Call the init process
        if InferenceEngine.inference_mp_group is None:
            init_distributed()

            local_rank = int(os.getenv('LOCAL_RANK', '0'))
            torch.cuda.set_device(local_rank)

            ranks = [i for i in range(self.mp_world_size)]
            self.mp_group = dist.new_group(ranks)
            InferenceEngine.inference_mp_group = self.mp_group
        else:
            self.mp_group = InferenceEngine.inference_mp_group
    def _init_quantization_setting(self, quantization_setting):
        self.quantize_bits = 8
        self.mlp_extra_grouping = False
        self.quantize_groups = 1
        if type(quantization_setting) is tuple:
            self.mlp_extra_grouping, \
            self.quantize_groups = quantization_setting
        elif quantization_setting is not None:
            self.quantize_groups = quantization_setting
        logger.info(f"quantize_bits = {self.quantize_bits} "
                    f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
                    f"quantize_groups = {self.quantize_groups}")
    def _validate_args(self, mpu):
        if not isinstance(self.module, Module):
            raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")

        if not isinstance(self.mp_world_size, int) or self.mp_world_size < 1:
            raise ValueError(f"mp_size must be an int >= 1, got {self.mp_world_size}")

        if mpu:
            methods = ["get_model_parallel_group", "get_data_parallel_group"]
            for method in methods:
                if not hasattr(mpu, method):
                    raise ValueError(f"mpu is missing {method}")

        if self.checkpoint is not None and not isinstance(self.checkpoint, str):
            raise ValueError(
                f"checkpoint must be None or a str, got {type(self.checkpoint)}")

        supported_dtypes = [None, torch.half, torch.int8, torch.float]
        if self.dtype not in supported_dtypes:
            raise ValueError(
                f"{self.dtype} not supported, valid dtype: {supported_dtypes}")

        if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
            raise ValueError(
                f"injection_dict must be None or a dict, got: {self.injection_dict}")
    def _apply_injection_policy(self,
                                client_module=None,
                                injection_policy=None,
                                return_tuple=True,
                                replace_with_kernel_inject=False):
        replace_transformer_layer(client_module,
                                  self.module,
                                  policy=injection_policy,
                                  mp_size=self.mp_world_size,
                                  mp_group=self.mp_group,
                                  config=self.config,
                                  fp16=(self.dtype == torch.half),
                                  training=False,
                                  return_tuple=return_tuple,
                                  quantize=(self.dtype == torch.int8),
                                  quantize_settings=(self.quantization_scales,
                                                     self.quantize_merge_count,
                                                     self.mlp_extra_grouping,
                                                     self.quantize_groups),
                                  replace_with_kernel_inject=replace_with_kernel_inject)
    def _load_checkpoint(self, load_dir, load_module_strict=True):
        sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)

        is_pipe_parallel = isinstance(self.module, PipelineModule)
        if is_pipe_parallel:
            raise RuntimeError(
                'pipeline parallelism is currently not supported in inference.')

        mp_rank = 0 if self.mp_group is None else dist.get_rank(group=self.mp_group)

        load_path, checkpoint, quantize_config = sd_loader.load(self.mp_world_size,
                                                                mp_rank,
                                                                is_pipe_parallel=is_pipe_parallel,
                                                                quantize=(self.dtype is torch.int8),
                                                                quantize_groups=self.quantize_groups,
                                                                mlp_extra_grouping=self.mlp_extra_grouping)
        self.quantization_scales, self.quantize_merge_count = quantize_config

        if is_pipe_parallel:
            # Pipeline parallelism uses this to load its own checkpoint files.
            self._curr_ckpt_path = load_dir

        self.module.load_state_dict(state_dict=checkpoint['model'],
                                    strict=load_module_strict)
    def _convert_to_dtype(self):
        if self.dtype is torch.int8 and self.quantization_scales is None:
            quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
            model, self.quantization_scales = quantizer.model_quantize(self.module,
                                                                       self.injection_dict,
                                                                       self.quantize_bits,
                                                                       self.quantize_groups)
        elif self.dtype == torch.half:
            self.module.half()
        elif self.dtype == torch.float:
            self.module.float()
    def _pre_forward_hook(self, module, *inputs, **kwargs):
        for input in inputs:
            if torch.is_tensor(input):
                input = input.to(torch.cuda.current_device())

        for k in kwargs:
            if torch.is_tensor(kwargs[k]):
                kwargs[k] = kwargs[k].to(torch.cuda.current_device())
    def forward(self, *inputs, **kwargs):
        """Execute forward propagation

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """
        if self.mp_world_size > 1:
            if self.mpu is None:
                # Rebuild the inputs tuple so the device transfer and contiguity
                # fix-ups actually propagate to the forward call below.
                moved_inputs = []
                for input in inputs:
                    if torch.is_tensor(input):
                        input = input.to(torch.cuda.current_device())
                        if not input.is_contiguous():
                            input = input.contiguous()
                    moved_inputs.append(input)
                inputs = tuple(moved_inputs)

                for k in kwargs:
                    if torch.is_tensor(kwargs[k]):
                        kwargs[k] = kwargs[k].to(torch.cuda.current_device())
                        if not kwargs[k].is_contiguous():
                            kwargs[k] = kwargs[k].contiguous()
                        dist.broadcast(kwargs[k], 0)

            outputs = self.model_orig_fwd(*inputs, **kwargs)
        else:
            outputs = self.module(*inputs, **kwargs)
        return outputs
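

# Minimal usage sketch: an InferenceEngine is normally created through
# deepspeed.init_inference() rather than constructed directly. The model name,
# half-precision choice, and generation arguments below are assumptions used
# only for illustration.
#
#   import torch
#   import deepspeed
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("gpt2")      # assumed model
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   engine = deepspeed.init_inference(model,
#                                     mp_size=1,
#                                     dtype=torch.half,
#                                     replace_method='auto',
#                                     replace_with_kernel_inject=True)
#   ids = tokenizer("DeepSpeed is", return_tensors="pt").input_ids.cuda()
#   output = engine.module.generate(ids, max_length=20)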