'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
import torch
import time
import os

from deepspeed import comm as dist
from deepspeed.utils.logging import log_dist

from torch import nn
from torch.nn.modules import Module
from packaging import version as pkg_version
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject import replace_transformer_layer, generic_injection
from ..comm.comm import init_distributed
from ..pipe import PipelineModule
from ..moe.utils import has_moe_layers
from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
from deepspeed.accelerator import get_accelerator
from ..module_inject.policy import TransformerPolicy
from ..module_inject.auto_tp import AutoTP

from ..module_inject.replace_policy import generic_policies

# Flipped to True the first time an InferenceEngine is constructed.
DS_INFERENCE_ENABLED = False

# Timer name used by the forward pre/post hooks when profiling with CUDA events.
INFERENCE_MODEL_TIMER = "model-forward-inference"


class InferenceEngine(Module):
    """Wrap a model for DeepSpeed inference.

    Handles checkpoint loading, dtype conversion, tensor/expert parallel group
    creation, module injection (user policy, kernel injection or AutoTP),
    optional CUDA-graph capture/replay and optional forward-time profiling.
    """
    # Process groups are cached on the class so that several engines in the same
    # process reuse the same tensor-parallel group (see _create_model_parallel_group).
    inference_mp_group = None
    inference_ep_group = None
    expert_mp_group = None

    def __init__(self, model, config):
        """
        Args:
            model: torch.nn.Module
            config: DeepSpeedInferenceConfig
        """
        global DS_INFERENCE_ENABLED
        DS_INFERENCE_ENABLED = True

        super().__init__()

        self.module = model
        self._config = config

        self._get_model_config_generate(config)  # keep for weird backward compatibility

        # patch model generate with ours if model uses it
        if hasattr(self.module, "generate"):
            self.generate = self._generate

        if hasattr(self.module, "config"):
            TransformerPolicy.hf_model_config = self.module.config

        # todo: keep this self.injection_dict because we don't use to change config.injection_policy API
        # todo: this will get changed when Molly's PR on auto injection dict is merged
        self.injection_dict = config.injection_policy

        # todo: refactor the mp_group and mp_size related in the next refactor
        self.mp_group = config.tensor_parallel.tp_group
        self.mpu = config.tensor_parallel.mpu

        #self._validate_args(self.mpu, config.replace_with_kernel_inject)
        self.quantize_merge_count = 1
        self.quantization_scales = None

        # these are not needed in the config as we are creating them ourselves in the inference engine
        self.ep_group = None  # config.moe.ep_group
        self.expert_mp_group = None  # config.moe.ep_mp_group

        self.cuda_graph_created = False
        self.checkpoint_engine = TorchCheckpointEngine()
        quantization_setting = None
        self._init_quantization_setting(
            quantization_setting)  # todo: update with the new quant config for weight quant
        self.model_profile_enabled = False
        self._model_times = []

        # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
        self.remove_mask_prepare_for_bloom()

        if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
            assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
                "If you want to use cuda graph, please upgrade torch to at least v1.10"

        # With kernel injection the checkpoint is consumed by _apply_injection_policy instead.
        if config.checkpoint and not config.replace_with_kernel_inject:
            self._load_checkpoint(config.checkpoint)

        # convert model to intended dtype
        if config.dtype:
            self._convert_to_dtype(config)

        if self.mpu:
            config.tensor_parallel.tp_size = dist.get_world_size(
                group=self.mpu.get_model_parallel_group())
            self.mp_group = self.mpu.get_model_parallel_group()
        elif config.tensor_parallel.tp_size > 1:
            self._create_model_parallel_group(config)
            config.tensor_parallel.tp_group = self.mp_group

        if isinstance(self.module, torch.nn.Module):
            moe, _ = has_moe_layers(self.module)
        else:
            moe = False

        if moe and dist.get_world_size() > 1:
            self._create_ep_parallel_group(config.moe.moe_experts)

        # retain this from the old conditional argument being passed to apply_injection_policy()
        if not config.replace_with_kernel_inject:
            config.checkpoint = None

        # We only support three modes: 1) user specified policy for tensor-parallelism,
        # 2) kernel injection (replace_with_kernel_inject), and 3) automatic tensor parallelism.
        if self.injection_dict:
            # 1. User specified Tensor Parallelism
            assert not config.replace_with_kernel_inject, "Cannot use both user specified injection policy and kernel injection"
            for client_module, injection_policy in self.injection_dict.items():
                # construct the tuple and pass that instead of a string or dict.
                if isinstance(injection_policy, str):
                    config.injection_policy_tuple = (injection_policy, )
                else:
                    config.injection_policy_tuple = injection_policy
                self._apply_injection_policy(config, client_module)
        elif config.replace_with_kernel_inject:
            # 2. DeepSpeed Kernel Injection
            self._apply_injection_policy(config)
        else:
            # 3. Automatic Tensor Parallelism
            # Iterated as (client_module, injection_policy) pairs.
            parser_dict = AutoTP.tp_parser(model)
            print("AutoTP: ", parser_dict)
            for client_module, injection_policy in parser_dict:
                if isinstance(injection_policy, str):
                    config.injection_policy_tuple = (injection_policy, )
                else:
                    config.injection_policy_tuple = injection_policy
                self._apply_injection_policy(config, client_module)

        device = get_accelerator().current_device_name()
        self.module.to(device)

        if config.tensor_parallel.tp_size > 1:
            # Broadcast rank 0's RNG state so every tensor-parallel rank samples identically.
            _rng_state = get_accelerator().get_rng_state().to(
                get_accelerator().current_device_name())
            dist.broadcast(_rng_state, 0)
            get_accelerator().set_rng_state(_rng_state.cpu())

            assert not config.enable_cuda_graph, "Cuda graph is not supported for model parallelism"

        # Check if local CUDA graphs can be created in replacement modules
        self.local_cuda_graph = self._local_cuda_graph_used(self.module)

    def profile_model_time(self, use_cuda_events=True):
        """Enable per-forward timing, collected via model_times().

        Without CUDA graphs, forward pre/post hooks do the timing; with CUDA
        graphs enabled, forward() times itself instead.

        Args:
            use_cuda_events: time with SynchronizedWallClockTimer instead of
                synchronize() + time.time() inside the hooks.
        """
        if not self.model_profile_enabled and not self._config.enable_cuda_graph:
            self.module.register_forward_pre_hook(self._pre_forward_hook)
            self.module.register_forward_hook(self._post_forward_hook)
        self.model_profile_enabled = True
        self.use_cuda_events = use_cuda_events
        if self.use_cuda_events:
            self.timers = SynchronizedWallClockTimer()

    # todo: remove this once all the config dicts are centralized from top level pydantic config
    def _get_model_config_generate(self, config):
        """Set self.config: config.config if given, else the wrapped module's `config` attribute (or None)."""
        # this is being passed to replace_transformer_layer(config=self.user_model_config_dict)
        self.config = getattr(self.module, 'config', None) if config.config is None else config.config

    def remove_mask_prepare_for_bloom(self):
        """Replace module.transformer._prepare_attn_mask, if present, with a pass-through of attention_mask."""
        if hasattr(self.module, 'transformer'):
            if hasattr(self.module.transformer, '_prepare_attn_mask'):
                self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask

    def _pre_forward_hook(self, module, *inputs, **kwargs):
        """Forward pre-hook: start the profiling timer."""
        if self.use_cuda_events:
            self.timers(INFERENCE_MODEL_TIMER).start()
        else:
            get_accelerator().synchronize()
            self._start = time.time()

    def _post_forward_hook(self, module, input, output):
        """Forward hook: stop the profiling timer and record the elapsed time."""
        if self.use_cuda_events:
            self.timers(INFERENCE_MODEL_TIMER).stop()
            elapsed_time = self.timers(INFERENCE_MODEL_TIMER).elapsed(reset=True)
        else:
            get_accelerator().synchronize()
            self._end = time.time()
            elapsed_time = self._end - self._start
        self._model_times.append(elapsed_time)

    def _create_model_parallel_group(self, config):
        """Create (once per process) a group of ranks [0, tp_size) and store it in self.mp_group."""
        # Call the init process
        if InferenceEngine.inference_mp_group is None:
            init_distributed()
            local_rank = int(os.getenv('LOCAL_RANK', '0'))
            get_accelerator().set_device(local_rank)

            ranks = list(range(config.tensor_parallel.tp_size))
            self.mp_group = dist.new_group(ranks)
            InferenceEngine.inference_mp_group = self.mp_group
        else:
            self.mp_group = InferenceEngine.inference_mp_group

    def _create_ep_parallel_group(self, moe_experts):
        """Create expert-parallel and expert-model-parallel groups.

        Args:
            moe_experts: an expert-parallel size, or a list of them. For each size
                the groups this rank belongs to are stored in self.ep_group and
                self.expert_mp_group, keyed by that size.
        """
        moe_experts = moe_experts if isinstance(moe_experts, list) else [moe_experts]
        self.ep_group = {e: None for e in moe_experts}
        self.expert_mp_group = {e: None for e in moe_experts}

        world_size = dist.get_world_size()
        rank = dist.get_rank()
        for moe_ep_size in self.ep_group.keys():
            # At least one group: when moe_ep_size exceeds the world size the single
            # group is clamped to all ranks below (previously no group was created).
            num_ep_groups = max(1, world_size // moe_ep_size)
            for i in range(num_ep_groups):
                ep_cnt = i * moe_ep_size
                size = world_size if moe_ep_size > world_size else moe_ep_size
                ranks = list(range(ep_cnt, ep_cnt + size))
                # new_group is called by every rank; only members keep the handle.
                _ep_group = dist.new_group(ranks)
                if rank in ranks:
                    self.ep_group[moe_ep_size] = _ep_group

            if world_size > moe_ep_size:
                num_expert_mp_groups = world_size // num_ep_groups
                expert_mp_size = world_size // moe_ep_size
                for i in range(num_expert_mp_groups):
                    # Strided ranks: i, i + moe_ep_size, i + 2 * moe_ep_size, ...
                    expert_mp_comm_ranks = [i + nr * moe_ep_size for nr in range(expert_mp_size)]
                    _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
                    if rank in expert_mp_comm_ranks:
                        self.expert_mp_group[moe_ep_size] = _expert_mp_group

    def _init_quantization_setting(self, quantization_setting):
        """Set quantize_bits / mlp_extra_grouping / quantize_groups and log them on rank 0.

        Args:
            quantization_setting: None (defaults), a (mlp_extra_grouping,
                quantize_groups) tuple, or a quantize_groups value.
        """
        self.quantize_bits = 8
        self.mlp_extra_grouping = False
        self.quantize_groups = 1
        if isinstance(quantization_setting, tuple):
            self.mlp_extra_grouping, self.quantize_groups = quantization_setting
        elif quantization_setting is not None:
            self.quantize_groups = quantization_setting
        log_dist(
            f"quantize_bits = {self.quantize_bits} "
            f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
            f"quantize_groups = {self.quantize_groups}",
            [0])

    # TODO: remove this function and add this functionality to pydantic config checking
    def _validate_args(self, mpu, replace_with_kernel_inject):
        """Validate engine arguments.

        Raises:
            ValueError: if any argument has an unsupported type or value.
        """
        # TODO: to support SD pipeline we need to avoid this check for now
        if replace_with_kernel_inject and not isinstance(self.module, Module):
            raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
        tp_size = self._config.tensor_parallel.tp_size
        if not isinstance(tp_size, int) or tp_size < 1:
            raise ValueError(f"mp_size must be an int >= 1, got {tp_size}")

        if mpu:
            for method in ("get_model_parallel_group", "get_data_parallel_group"):
                if not hasattr(mpu, method):
                    raise ValueError(f"mpu is missing {method}")
        if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
            raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")

        # bfloat16 is accepted because _convert_to_dtype handles it.
        supported_dtypes = [None, torch.half, torch.bfloat16, torch.int8, torch.float]
        if self._config.dtype not in supported_dtypes:
            raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")

        if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
            raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")

    def load_model_with_checkpoint(self, r_module):
        """Copy tensors from self.sd into matching layers of r_module.

        Reads self.sd (state dict) and self.key_list (its keys). The layer types
        in layer_policies have their weight/bias copied through a
        ReplaceWithTensorSlicing helper.
        """
        self.mp_replace = ReplaceWithTensorSlicing(
            mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)

        def load(module, state_dict, prefix):
            if hasattr(module, 'weight'):
                if 'query_key_value' in prefix:
                    module.weight = self.mp_replace.qkv_copy(module.weight.data, state_dict[prefix + 'weight'])
                else:
                    module.weight = self.mp_replace.copy(module.weight.data, state_dict[prefix + 'weight'])
            else:
                # No direct weight: the parameters live on a wrapped `norm` submodule.
                module.norm.weight = self.mp_replace.copy(module.norm.weight.data, state_dict[prefix + 'weight'])
            if prefix + 'bias' in self.key_list:
                if hasattr(module, 'norm'):
                    module.norm.bias = self.mp_replace.copy(module.norm.bias, state_dict[prefix + 'bias'])
                else:
                    data = state_dict[prefix + 'bias']
                    data = data.to(get_accelerator().current_device_name())
                    module.bias = self.mp_replace.copy(module.bias, data)

        layer_policies = {
            nn.Linear: load,
            nn.Embedding: load,
            nn.LayerNorm: load,
            LinearLayer: load,
            LinearAllreduce: load
        }

        def load_module_recursive(module, prefix='', level=0):
            for name, child in module.named_children():
                if child.__class__ in layer_policies:
                    checking_key = prefix + name + '.'
                    # Skip layers that have no entry in this checkpoint shard.
                    if not any(checking_key in item for item in self.key_list):
                        continue
                    params = list(child.parameters())
                    # Empty (numel == 0) parameters: 1-D ones are replaced by a
                    # Normalize module sized from weight.ds_shape before loading.
                    if params and params[0].numel() == 0:
                        if len(child.weight.ds_shape) == 1:
                            child = Normalize(dim=child.weight.ds_shape[-1],
                                              dtype=child.weight.dtype,
                                              eps=child.eps)
                            setattr(module, name, child)
                    load(child, self.sd, prefix + name + '.')
                else:
                    # The top-level child's name is not added to the prefix (level == 0).
                    load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1)

        load_module_recursive(r_module)

    def _apply_injection_policy(self, config, client_module=None):
        """Run generic_injection and, for nn.Module models, replace_transformer_layer.

        Args:
            config: the DeepSpeedInferenceConfig.
            client_module: module class to replace; passed by the user
                injection_dict path and the AutoTP path, None for kernel injection.
        """
        checkpoint_dir = config.checkpoint
        checkpoint = SDLoaderFactory.get_sd_loader_json(
            checkpoint_dir, self.checkpoint_engine) if checkpoint_dir is not None else None

        generic_injection(self.module,
                          fp16=(config.dtype == torch.half) or (config.dtype == torch.int8),
                          enable_cuda_graph=config.enable_cuda_graph)

        if isinstance(self.module, torch.nn.Module):
            # config is our DeepSpeedInferenceConfig and self.config is the HF model config
            replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)

    def _get_all_ckpt_names(self, checkpoints_path, tag):
        """Return the sorted list of "mp_rank_*_model_states.pt" files in checkpoints_path."""
        ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
        import glob

        return sorted(glob.glob(ckpt_file_pattern))

    def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
        """Build "<checkpoints_path>/mp_rank_<XX>_model_states.pt".

        XX is mp_placeholder when given, else this rank's model-parallel rank
        formatted as two digits. NOTE(review): `tag` is accepted but unused.
        """
        if mp_placeholder is not None:
            mp_rank_str = mp_placeholder
        else:
            mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
            mp_rank_str = "{:02d}".format(mp_rank)

        return os.path.join(checkpoints_path, "mp_rank_" + mp_rank_str + "_model_states.pt")

    def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
        """Load a checkpoint directory or JSON descriptor into self.module.

        Args:
            load_dir: a checkpoint directory or a dict/JSON descriptor for SDLoaderFactory.
            load_module_strict: forwarded as `strict` to load_state_dict.
            tag: checkpoint tag; read from "<load_dir>/latest" when None and present.

        Raises:
            RuntimeError: if the module is a PipelineModule.
        """
        is_pipe_parallel = isinstance(self.module, PipelineModule)
        if is_pipe_parallel:
            raise RuntimeError('pipeline parallelism is currently not supported in inference.')
        if not isinstance(load_dir, dict) and os.path.isdir(load_dir):
            if tag is None:
                latest_path = os.path.join(load_dir, "latest")
                if os.path.isfile(latest_path):
                    with open(latest_path, "r") as fd:
                        tag = fd.read().strip()

            ckpt_list = self._get_all_ckpt_names(load_dir, tag)
            sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
        else:
            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir, self.checkpoint_engine)

        if isinstance(sd_loader, list):
            # A plain list of shard paths: load and apply one shard at a time.
            self.sd = torch.load(sd_loader[0], map_location='cpu')
            self.key_list = list(self.sd.keys())

            self.load_model_with_checkpoint(self.module)

            # NOTE(review): later shards are mapped to the accelerator while the
            # first one goes to CPU; kept as-is.
            for i, shard in enumerate(sd_loader[1:], start=1):
                if not dist.is_initialized() or dist.get_rank() == 0:
                    print(f"loading checkpoint ({i})")
                self.sd = torch.load(shard, map_location=get_accelerator().device_name())
                self.key_list = list(self.sd.keys())
                self.load_model_with_checkpoint(self.module)
        else:
            mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()

            load_path, checkpoint, quantize_config = sd_loader.load(self._config.tensor_parallel.tp_size,
                                                                    mp_rank,
                                                                    is_pipe_parallel=is_pipe_parallel,
                                                                    quantize=(self._config.dtype is torch.int8),
                                                                    quantize_groups=self.quantize_groups,
                                                                    mlp_extra_grouping=self.mlp_extra_grouping)

            self.quantization_scales, self.quantize_merge_count = quantize_config

            moe, _ = has_moe_layers(self.module)
            module_key = self._choose_module_key(checkpoint)
            if moe:
                from deepspeed.runtime.engine import DeepSpeedEngine
                # A non-list 'num_experts' marks the old MoE checkpoint layout.
                old_moe_load = not isinstance(checkpoint['num_experts'], list)
                DeepSpeedEngine.load_moe_state_dict(load_dir,
                                                    tag,
                                                    state_dict=checkpoint[module_key],
                                                    old_moe_load=old_moe_load,
                                                    model=self.module,
                                                    mpu=self.mpu,
                                                    checkpoint_engine=self.checkpoint_engine)

            self.module.load_state_dict(state_dict=checkpoint[module_key], strict=load_module_strict)

    def _choose_module_key(self, sd):
        """Return 'module' or 'model', whichever single one of them is a key of sd."""
        assert not ('module' in sd and 'model' in sd), "checkpoint has both 'model' and 'module' keys, not sure how to proceed"
        assert 'module' in sd or 'model' in sd, "checkpoint contains neither 'model' or 'module' keys, not sure how to proceed"
        return 'module' if 'module' in sd else 'model'

    def _convert_to_dtype(self, config):
        """Cast self.module to config.dtype (half, bfloat16 or float); no-op for non-nn.Module models."""
        if not isinstance(self.module, torch.nn.Module):
            return

        # NOTE: int8 weight quantization via WeightQuantization is currently disabled;
        # an int8 dtype therefore leaves the module's dtype unchanged here.
        if config.dtype == torch.half:
            self.module.half()
        elif config.dtype == torch.bfloat16:
            self.module.bfloat16()
        elif config.dtype == torch.float:
            self.module.float()

    def _create_cuda_graph(self, *inputs, **kwargs):
        """Warm up on a side stream, then capture one forward pass into a CUDA graph."""
        # warmup to create the workspace and cublas handle
        cuda_stream = get_accelerator().Stream()
        cuda_stream.wait_stream(get_accelerator().current_stream())
        with get_accelerator().stream(cuda_stream):
            for _ in range(3):
                self.module(*inputs, **kwargs)
        get_accelerator().current_stream().wait_stream(cuda_stream)

        # create cuda_graph and assign static_inputs and static_outputs
        self._cuda_graphs = torch.cuda.CUDAGraph()
        self.static_inputs = inputs
        self.static_kwargs = kwargs

        with torch.cuda.graph(self._cuda_graphs):
            self.static_output = self.module(*self.static_inputs, **self.static_kwargs)

        self.cuda_graph_created = True

    def _graph_replay(self, *inputs, **kwargs):
        """Copy tensor arguments into the captured static buffers, replay, return the static output."""
        for i, value in enumerate(inputs):
            if torch.is_tensor(value):
                self.static_inputs[i].copy_(value)
        for k, value in kwargs.items():
            if torch.is_tensor(value):
                self.static_kwargs[k].copy_(value)
        self._cuda_graphs.replay()
        return self.static_output

    def model_times(self):
        """Return the recorded forward times and reset the record.

        Raises:
            ValueError: if CUDA graphs are enabled and no times were recorded.
        """
        assert self.model_profile_enabled, "model profiling is not enabled"
        model_times = self._model_times
        if self._config.enable_cuda_graph and len(self._model_times) == 0:
            raise ValueError("Model times are empty and cuda graph is enabled. If "
                             "this is a GPT-style model this combo is not supported. If this is a "
                             "BERT-style model this is a bug, please report it. "
                             f"Model type is: {type(self.module)}")
        self._model_times = []
        return model_times

    def _module_match(self, module):
        """True if any policy in generic_policies reports `module` as already replaced."""
        return any(policy().match_replaced(module) for policy in generic_policies)

    def _local_cuda_graph_used(self, module):
        """True if a non-nn.Module container holds a replaced module with `enable_cuda_graph`.

        Always False for torch.nn.Module instances.
        """
        if isinstance(module, torch.nn.Module):
            return False

        sub_module_cuda_graph = False
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
            if self._module_match(sub_module) and hasattr(sub_module, "enable_cuda_graph"):
                sub_module_cuda_graph = True
        return sub_module_cuda_graph

    def forward(self, *inputs, **kwargs):
        """Execute forward propagation

        Arguments:
            *inputs: Variable length input list
            **kwargs: variable length keyword arguments
        """
        use_cuda = get_accelerator().device_name() == 'cuda'

        # With CUDA graphs no forward hooks are registered, so time the call here.
        start = None
        if self.model_profile_enabled and use_cuda and self._config.enable_cuda_graph:
            get_accelerator().synchronize()
            start = time.time()

        if use_cuda and self._config.enable_cuda_graph and not self.local_cuda_graph:
            if not self.cuda_graph_created:
                self._create_cuda_graph(*inputs, **kwargs)
            outputs = self._graph_replay(*inputs, **kwargs)
        else:
            outputs = self.module(*inputs, **kwargs)

        # Only stop the timer if it was started (start stays None off CUDA).
        if start is not None:
            get_accelerator().synchronize()
            duration = time.time() - start
            self._model_times.append(duration)

        return outputs

    def _generate(self, *inputs, **kwargs):
        """Proxy for module.generate that resets the KV cache and rejects beam search.

        Raises:
            NotImplementedError: if num_beams > 1 (via kwargs or generation_config).
        """
        # Reset KV-cache at the beginning of generate
        if hasattr(self.module, 'reset_cache'):
            self.module.reset_cache()
        num_beams = 1
        if "generation_config" in kwargs:
            gen_config = kwargs["generation_config"]
            num_beams = getattr(gen_config, "num_beams", 1)
        if "num_beams" in kwargs:
            num_beams = kwargs["num_beams"]

        if num_beams > 1:
            raise NotImplementedError(
                "DeepSpeed does not support `num_beams` > 1, if this is important to you please "
                "add your request to: https://github.com/microsoft/DeepSpeed/issues/2506")

        return self.module.generate(*inputs, **kwargs)