- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import os
- import torch
- import tqdm
- import deepspeed
- import deepspeed.ops.transformer as transformer_inference
- from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
- from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
- from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
- from deepspeed.accelerator import get_accelerator
- from .replace_policy import replace_policies, generic_policies
- from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
- from deepspeed import comm as dist
- from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd
- from .load_checkpoint import load_model_with_checkpoint
- import time
- from .utils import policy_to_ds_container
- import gc
def get_transformer_name(replaced_module):
    """Return the dotted name prefix of the transformer ModuleList inside *replaced_module*.

    Scans the direct children for the first one whose class is registered in
    ``supported_models``; the result is that child's name plus a ``'.'``,
    followed by the name of its first ``ModuleList`` child (if any).
    Returns ``''`` when no supported child is found.
    """
    from .containers import supported_models
    from torch.nn import ModuleList

    for outer_name, outer_child in replaced_module.named_children():
        if outer_child.__class__ not in supported_models:
            continue
        prefix = outer_name + '.'
        for inner_name, inner_child in outer_child.named_children():
            if inner_child.__class__ is ModuleList:
                return prefix + inner_name
        return prefix
    return ''
class GroupQuantizer:
    """Symmetric per-group int8 quantizer applied to inference weights.

    ``quantize`` returns a frozen ``torch.nn.Parameter`` carrying a ``scale``
    attribute; when quantization is disabled the input is passed through with a
    placeholder scale.
    """

    def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
        # num_groups == 0 means "derive from the tensor's leading dim / group_size"
        # at quantize() time.
        self.group_size = group_size
        self.num_bits = num_bits
        self.q_int8 = q_int8
        self.num_groups = num_groups

    def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
        # Pass-through path: wrap as a frozen Parameter with an uninitialized
        # 1-element scale placeholder.
        if not self.q_int8 or not qkv:
            passthrough = torch.nn.Parameter(inputs, requires_grad=False)
            passthrough.scale = torch.empty(1)
            return passthrough

        levels = 2**self.num_bits
        groups = self.num_groups if self.num_groups > 0 else inputs.shape[0] // self.group_size
        inputs = inputs.to(get_accelerator().current_device_name())

        # Whole-tensor symmetric scale, one per group.
        flat = inputs.reshape(groups, -1).contiguous()
        lo = torch.min(flat, dim=1, keepdim=True)[0].float()
        hi = torch.max(flat, dim=1, keepdim=True)[0].float()
        scale = torch.max(lo.abs(), hi.abs()) * 2.0 / levels
        quantized = (flat / scale).round().clamp(-levels // 2, levels // 2 - 1)
        out = torch.nn.Parameter(quantized.reshape(inputs.shape).to(torch.int8).contiguous(),
                                 requires_grad=False)

        # Additional per-half scales over the two splits along parallel_dim
        # (assumes that dimension is evenly divisible by 2).
        halves = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
        half_scales = []
        for half in halves[:2]:
            half_flat = half.reshape(groups, -1).contiguous()
            half_lo = torch.min(half_flat, dim=1, keepdim=True)[0].float()
            half_hi = torch.max(half_flat, dim=1, keepdim=True)[0].float()
            half_scales.append((torch.max(half_lo.abs(), half_hi.abs()) * 2.0 / levels).squeeze().unsqueeze(0))

        # Final scale layout: [whole, first half, second half] folded to (groups, -1).
        out.scale = torch.cat([scale.squeeze().unsqueeze(0)] + half_scales,
                              dim=0).reshape(groups, -1).contiguous()
        return out
def _module_match(module):
    """Return the first instantiated generic policy that matches *module*, else None."""
    for policy_cls in generic_policies:
        candidate = policy_cls()
        if candidate.match(module):
            return candidate
    return None
def generic_injection(module, dtype=None, enable_cuda_graph=True):
    """Inject DeepSpeed inference kernels into a diffusers-style pipeline.

    Replaces cross-attention modules with DeepSpeedDiffusersAttention and
    transformer blocks with DeepSpeedDiffusersTransformerBlock, and wraps the
    pipeline's ``text_encoder`` in DSClipEncoder. Only acts when *module* is
    NOT a plain ``torch.nn.Module`` (i.e. it is presumably a pipeline object
    exposing ``text_encoder`` — TODO confirm against callers); FP16 is required.
    """

    def replace_attn(child, policy):
        # Extract attention weights via the policy; a 5-tuple means a fused
        # QKV weight, a 7-tuple means separate Q/K/V weights.
        policy_attn = policy.attention(child)
        if policy_attn is None:
            return child
        if len(policy_attn) == 5:
            qkvw, attn_ow, attn_ob, hidden_size, heads = policy_attn
        else:
            qw, kw, vw, attn_ow, attn_ob, hidden_size, heads = policy_attn

        config = transformer_inference.DeepSpeedInferenceConfig(
            hidden_size=hidden_size,
            heads=heads,
            dtype=dtype,
            triangular_masking=False,
            max_out_tokens=4096,
        )
        attn_module = DeepSpeedDiffusersAttention(config)

        def transpose(data):
            # In-place transpose trick: overwrite the flat storage with the
            # transposed values, then reinterpret with swapped trailing dims.
            data = data.contiguous()
            data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1))
            data = data.reshape(data.shape[-1], data.shape[-2])
            # NOTE(review): `.to()` is not in-place and its result is discarded
            # here — this line looks like a no-op; confirm intended device move.
            data.to(get_accelerator().current_device_name())
            return data

        if len(policy_attn) == 5:
            attn_module.attn_qkvw.data = transpose(qkvw.data)
        else:
            # Separate projections: fused qkvw is unused in this layout.
            attn_module.attn_qkvw = None
            attn_module.attn_qw.data = transpose(qw.data)
            attn_module.attn_kw.data = transpose(kw.data)
            attn_module.attn_vw.data = transpose(vw.data)

        attn_module.attn_qkvb = None
        attn_module.attn_ow.data = transpose(attn_ow.data)
        attn_module.attn_ob.data.copy_(attn_ob.data.to(get_accelerator().current_device_name()))
        return attn_module

    def replace_attn_block(child, policy):
        # Wrap the whole transformer block with the DeepSpeed kernel version.
        config = Diffusers2DTransformerConfig()
        return DeepSpeedDiffusersTransformerBlock(child, config)

    if isinstance(module, torch.nn.Module):
        # Plain nn.Module: nothing to do — generic injection targets pipelines.
        pass
    else:
        if dtype not in [torch.float16, torch.half]:
            raise ValueError("Generic injection only supported with FP16")

        try:
            import diffusers
            # diffusers renamed CrossAttention -> attention_processor.Attention;
            # probe for whichever this installed version provides.
            if hasattr(diffusers.models.attention, 'CrossAttention'):
                cross_attention = diffusers.models.attention.CrossAttention
            else:
                cross_attention = diffusers.models.attention_processor.Attention
            attention_block = diffusers.models.attention.BasicTransformerBlock
            new_policies = {
                cross_attention: replace_attn,
                attention_block: replace_attn_block,
            }
        except ImportError:
            new_policies = {}

        #replace_transformer_layer(None,
        #                          module.text_encoder,
        #                          training=False,
        #                          replace_with_kernel_inject=True,
        #                          triangular_masking=True,
        #                          max_out_tokens=8192)
        from ..model_implementations.transformers.clip_encoder import DSClipEncoder
        cg_encoder = DSClipEncoder(module.text_encoder, enable_cuda_graph=enable_cuda_graph)
        setattr(module, 'text_encoder', cg_encoder)
        for name in module.__dict__.keys():
            sub_module = getattr(module, name)
            policy = _module_match(sub_module)

            if policy is not None:

                def _replace_module(module, policy):
                    # Depth-first so children are swapped before their parent
                    # is inspected.
                    for name, child in module.named_children():
                        _replace_module(child, policy)
                        if child.__class__ in new_policies:
                            replaced_module = new_policies[child.__class__](child, policy)
                            setattr(module, name, replaced_module)

                _replace_module(sub_module, policy)
                new_module = policy.apply(sub_module, enable_cuda_graph=enable_cuda_graph)
                print(f"**** found and replaced {name} w. {type(new_module)}")
                setattr(module, name, new_module)
# Global handle to the first model-specific container created during kernel
# injection; reused later for generic (meta-tensor) checkpoint loading.
container_g = None


def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, model_config):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        checkpoint_dict: Dictionary for checkpoint passed from the Inference Engine
        config: top-level DS Inference config defined in inference/config.py
        model_config: HuggingFace model config passed from the inference/engine.py
    Returns:
        Updated nn.module with replaced transformer layers
    """
    # defining globals as internally defined functions inherit these everywhere
    quantize = (config.dtype == torch.int8)
    # todo: Refactor later. In future, let's minimize the style used above and use config.** instead

    linear_layer_setting = None
    '''
        linear_layer_setting (tuple of modules) [Optional]: shows which two classes are used for linear layers and embedding layers
    '''
    micro_batch_size = -1
    seed = -1
    local_rank = -1

    # Helper that shards/copies tensors according to the TP group/size.
    mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group,
                                          mp_size=config.tensor_parallel.tp_size)  #, out_dim=0, in_dim=1)

    def replace_with_policy(child, policy_cls, triangular_masking, inference=False, layer_id=0):
        # Kernel-injection path: build a DS container from the policy and let it
        # construct/populate the fused inference module.
        policy = policy_cls(child, inference=inference)
        if not policy.cuda_graph_supported:
            # policy says cuda graph is not supported raise an error if set
            assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"

        from deepspeed.moe.layer import MoE
        moe = False
        if hasattr(child, 'mlp') and isinstance(child.mlp, MoE):
            # num_experts is read but not otherwise used in this scope.
            num_experts = child.mlp.num_experts
            moe = True

        # 1. Create a model-specific container object using the policy object.
        _container = policy_to_ds_container(policy=policy,
                                            config=config,
                                            model_config=model_config,
                                            layer_id=layer_id,
                                            child=child)
        _container.set_moe(moe)

        # 2. Set the tensor parallelism config
        _container.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Initialize tensors
        _container.initialize_tensors()

        # 4. deal with data types -- needs refactor to use dtype instead of fp16
        if config.dtype in [torch.float16, torch.bfloat16, torch.int8]:
            _container.convert_to_required_dtype()

        # 5. Set the quantization config
        quantizer = GroupQuantizer(q_int8=quantize)
        _container.set_quantization_config(quantizer)

        # 6. create a DS Inference config object
        _container.create_ds_model_config()

        # 7. use the config and create the module
        _container.create_module()

        # 8. transpose the weights and bias if needed
        _container.transpose()

        # 9. deal with tensor parallelism.
        _container.apply_tensor_parallelism(mp_replace)

        # 10. copy the tensors from the model-specific container to the new module
        _container.copy_data_to_new_module()

        # 11. set global for generic checkpoint loading
        global container_g

        if container_g is None:
            container_g = _container

        return _container.module

    def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
        # AutoTP path: shard plain Linear layers across the TP group without
        # model-specific kernels.
        #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)

        # 1. Create AutoTP object
        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)

        # 2. Set the tensor parallelism config
        _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

        # 3. Try to get num_key_heads from model_config.num_key_value_heads
        num_kv_heads = _autotp.get_model_num_kv_heads(model_config)

        # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
        set_num_kv_heads(num_kv_heads)

        # 4.1 Get n_embd
        n_embd = None
        multi_query_n_embd_names = ['n_embd']
        for name in multi_query_n_embd_names:
            if hasattr(model_config, name):
                n_embd = getattr(model_config, name)
            if n_embd != None:
                break

        # 4.2 set n_embd
        set_n_embd(n_embd)

        # 5. Set linear policies
        _autotp.update_linear_policies()

        # 6. Replace modules
        if "lm_head" in all_reduce_linears or "embed_out" in all_reduce_linears:
            return _autotp._replace_last_linear_module(module)
        return _autotp._replace_module(module)

    def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
        # Dispatch to kernel injection or AutoTP depending on the config.
        training = False  # todo: refactor this part to go in the config
        if training:
            # copy relevant state from child -> new module
            new_module = replace_with_policy(child, _policy, config.triangular_masking)

        else:
            # copy relevant state from child -> new module
            if config.replace_with_kernel_inject:
                new_module = replace_with_policy(child,
                                                 _policy,
                                                 config.triangular_masking,
                                                 inference=True,
                                                 layer_id=layer_id)
            else:
                new_module = replace_wo_policy(child, _policy, prefix=prefix, state_dict=state_dict)

        return new_module

    def set_lm_head(module):
        # Tie a meta lm_head to the input embedding weight, then TP-shard the
        # final projection when it is a real (non-meta) Linear.
        embedding_weight = None
        for n, p in module.named_parameters():
            if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
                embedding_weight = p
        if embedding_weight is not None and hasattr(module, "lm_head") and hasattr(
                module.lm_head, "weight") and module.lm_head.weight.is_meta:
            module.lm_head.weight = embedding_weight
        # enable tensor parallel for the last linear
        if hasattr(module, "lm_head") and hasattr(module.lm_head,
                                                  "weight") and not module.lm_head.weight.is_meta and isinstance(
                                                      module.lm_head, torch.nn.Linear):
            # NOTE(review): positionally these land in replace_wo_policy's
            # prefix=0 and state_dict="lm_head" slots — confirm intended.
            module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
        elif hasattr(module, "embed_out") and hasattr(module.embed_out,
                                                      "weight") and not module.embed_out.weight.is_meta and isinstance(
                                                          module.embed_out, torch.nn.Linear):
            module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
        return module

    if checkpoint_dict is not None and not config.replace_with_kernel_inject:
        # AutoTP shard loading
        checkpoint = checkpoint_dict["checkpoints"]
        pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")
        for i in range(len(checkpoint)):
            checkpoint_file = os.path.join(config.base_dir, checkpoint[i])
            # Each shard is loaded and applied during traversal; the returned
            # module from the last pass is the fully-replaced model.
            replaced_module = replace_module(model=model,
                                             orig_class=orig_layer_impl,
                                             replace_fn=replace_fn,
                                             _replace_policy=config.injection_policy_tuple,
                                             checkpoint=checkpoint_file)
            pbar.update(1)
            gc.collect()
        replaced_module = set_lm_head(replaced_module)
    else:
        replaced_module = replace_module(model=model,
                                         orig_class=orig_layer_impl,
                                         replace_fn=replace_fn,
                                         _replace_policy=config.injection_policy_tuple)

    quantizer = GroupQuantizer(q_int8=quantize)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    rank = dist.get_rank() if dist.is_initialized() else 0
    if checkpoint_dict is not None and config.replace_with_kernel_inject:
        # Meta-tensor checkpoint loading through the container created above.
        assert container_g.ckpt_load_enabled, \
            f"Meta Tensor checkpoint loading not supported in {container_g.__class__.__name__} container"
        start_time = time.time()
        checkpoint = checkpoint_dict['checkpoints']
        ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
        ckpt_type = checkpoint_dict.get('parallelization', 'pp')
        # 'tp_size' is preferred; fall back to legacy 'mp_size', then shard count.
        ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
        ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
        base_dir1 = checkpoint_dict.get('base_dir', config.base_dir)

        if ckpt_type == 'pp' and type(checkpoint) is list:
            pbar = tqdm.tqdm(total=len(checkpoint), desc=f"Loading {len(checkpoint)} checkpoint shards")

            for i in range(len(checkpoint)):
                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
                load_model_with_checkpoint(replaced_module,
                                           sd,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           container=container_g)
                pbar.update(1)
        else:
            # TP checkpoints: map this rank onto its slice of the shard list.
            num_checkpoints = len(ckpt_list) // ckpt_mp_size
            tp_split_size = (world_size / ckpt_mp_size)
            sd_offset = int(rank / tp_split_size)
            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
            pbar = tqdm.tqdm(total=num_checkpoints, desc=f"Loading {num_checkpoints} checkpoint shards")
            for i in range(num_checkpoints):
                pbar.update(1)
                ckpt_index = i * ckpt_mp_size + sd_offset
                ckpt_files = [
                    os.path.join(base_dir1, ckpt_list[ckpt_index + j]) if base_dir1 else ckpt_list[ckpt_index + j]
                    for j in range(sd_count)
                ]
                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
                load_model_with_checkpoint(replaced_module,
                                           sds,
                                           mp_replace,
                                           ckpt_type,
                                           ckpt_mp_size,
                                           quantizer,
                                           int(rank % tp_split_size),
                                           container=container_g)
                # Drop references so the big state_dicts can be collected.
                sds = [None for _ in sds]
                gc.collect()

            if "non_tp" in checkpoint:
                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")

                for i in range(len(checkpoint["non_tp"])):
                    pbar.update(1)
                    ckpt_file = os.path.join(base_dir1,
                                             checkpoint["non_tp"][i]) if base_dir1 else checkpoint["non_tp"][i]
                    sds = [torch.load(ckpt_file, map_location='cpu')]
                    load_model_with_checkpoint(replaced_module,
                                               sds,
                                               mp_replace,
                                               ckpt_type,
                                               ckpt_mp_size,
                                               quantizer,
                                               int(rank % tp_split_size),
                                               container=container_g)
                    sds = [None for _ in sds]
                    gc.collect()
        set_lm_head(replaced_module)
        print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

    if config.save_mp_checkpoint_path is not None:
        # Re-save the now-sharded model as TP checkpoint files plus a config json.
        from collections import OrderedDict
        import json
        num_partitions = 8

        if checkpoint_dict is None:
            ckpt_name = "ds_model"
            try:
                from transformers.models.bloom.modeling_bloom import BloomForCausalLM
                if isinstance(model, BloomForCausalLM):
                    ckpt_name = "bloom"
            except ImportError:
                ckpt_name = "ds_model"
        else:
            ckpt_name = checkpoint_dict['type']
        if dist.is_initialized():
            dist.barrier()
        transformer_name = get_transformer_name(replaced_module)
        non_tp_ckpt_name = f'non-tp.pt'
        ckpt_files = [non_tp_ckpt_name]
        os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)
        if not dist.is_initialized() or dist.get_rank() == 0:
            # Rank 0 saves the non-transformer ("non_tp") weights and the json.
            print("Saving tp-sharded checkpoints")
            torch.save(
                OrderedDict({k: v
                             for k, v in dict(replaced_module.state_dict()).items()
                             if transformer_name not in k}), f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')

            dtype_reprs = {
                torch.float32: 'float32',
                torch.float16: 'float16',
                torch.int8: 'int8',
                torch.bfloat16: 'bfloat16'
            }

            ckpt_config = json.dumps({
                'type': ckpt_name,
                'base_dir': f'{config.save_mp_checkpoint_path}',
                'checkpoints': {
                    "non_tp": ckpt_files,
                    "tp": [f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions) for r in range(world_size)]
                },
                'version': 1.0,
                'parallelization': 'tp',
                'tp_size': world_size,
                'dtype': dtype_reprs[config.dtype]
            })
            with open(f"{config.save_mp_checkpoint_path}/ds_inference_config.json", "w") as cfg:
                cfg.write(ckpt_config)

        # Every rank saves its transformer weights split into num_partitions files;
        # quantized params are stored as [tensor, scale] pairs.
        rep_sd = replaced_module.state_dict()
        for n, p in replaced_module.named_parameters():
            if hasattr(p, 'scale'):
                rep_sd[n] = [p, p.scale]
        keys = list(rep_sd.keys())
        partition_size = (len(keys) // num_partitions + 1)
        for m in range(num_partitions):
            torch.save(
                OrderedDict({
                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
                    for k in keys[m * partition_size:(m + 1) * partition_size] if transformer_name in k
                }), f'{config.save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

    return replaced_module
def revert_transformer_layer(orig_layer_impl, model, config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.models.bert.modeling_bert.BertLayer or transformers.BertLayer
        model (torch.nn.Module): user's nn.module representing their model
        config (dict): model config containing hidden size, attention heads, etc.
    Returns:
        Updated nn.module with original bert-style transformer layers
    """

    def replace_fn(child, _replace_policy, layer_id):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(config)

        # Split the fused QKV weight/bias back into per-projection tensors.
        q_w, k_w, v_w = torch.chunk(child.attn_qkvw.data, 3, axis=0)
        q_b, k_b, v_b = torch.chunk(child.attn_qkvb.data, 3, axis=0)

        self_attn = orig_module.attention.self
        self_attn.query.weight.data = q_w
        self_attn.query.bias.data = q_b
        self_attn.key.weight.data = k_w
        self_attn.key.bias.data = k_b
        self_attn.value.weight.data = v_w
        self_attn.value.bias.data = v_b

        attn_output = orig_module.attention.output
        attn_output.dense.weight.data = child.attn_ow.data
        attn_output.dense.bias.data = child.attn_ob.data

        # The attention layer-norm lives elsewhere in pre-LN layouts.
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = child.attn_nw.data
            orig_module.PostAttentionLayerNorm.bias.data = child.attn_nb.data
        else:
            attn_output.LayerNorm.weight.data = child.attn_nw.data
            attn_output.LayerNorm.bias.data = child.attn_nb.data

        # Intermediate (FFN up-projection) weights.
        if preln:
            orig_module.intermediate.dense_act.weight.data = child.inter_w.data
            orig_module.intermediate.dense_act.bias.data = child.inter_b.data
        else:
            orig_module.intermediate.dense.weight.data = child.inter_w.data
            orig_module.intermediate.dense.bias.data = child.inter_b.data

        # FFN down-projection.
        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        # Final transformer layer-norm.
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = child.norm_w.data
            orig_module.PreAttentionLayerNorm.bias.data = child.norm_b.data
        else:
            orig_module.output.LayerNorm.weight.data = child.norm_w.data
            orig_module.output.LayerNorm.bias.data = child.norm_b.data

        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn,
                          _replace_policy=None)
def replace_module(model, orig_class, replace_fn, _replace_policy, checkpoint=None):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
                             desired type and return a new instance.
        _replace_policy: injection policy forwarded to ``replace_fn``; when
                         ``orig_class`` is None, the registered ``replace_policies``
                         are scanned to build the class->policy mapping instead.
        checkpoint (str, optional): path to a checkpoint shard (``.pt`` or
                                    ``.safetensors``) whose state_dict is loaded
                                    into the modules during replacement.
    Returns:
        A modified ``model``.
    """
    sd = None
    if checkpoint is not None:
        if checkpoint.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd = load_file(checkpoint)
        else:
            sd = torch.load(checkpoint, map_location='cpu')

    policy = {}
    if orig_class is not None:
        policy.update({orig_class: (replace_fn, _replace_policy)})
    else:
        for plcy in replace_policies:
            # instantiate a throw-away policy in order to populate the _orig_layer_class
            _ = plcy(None)
            if isinstance(plcy._orig_layer_class, list):
                # A policy may map to several HF layer classes.
                for orig_layer_class in plcy._orig_layer_class:
                    policy.update({orig_layer_class: (replace_fn, plcy)})
            elif plcy._orig_layer_class is not None:
                policy.update({plcy._orig_layer_class: (replace_fn, plcy)})
    # Fixed typo in the example policy name (was "HFBEertLayerPolicy").
    assert len(policy.items()) > 0,\
        "No default policy found! Please specify your policy injection_policy (like {BertLayer:HFBertLayerPolicy})." +\
        "You can find some samples here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/replace_policy.py"

    replaced_module, _ = _replace_module(model, policy, state_dict=sd)
    return replaced_module
- from ..pipe import PipelineModule
- import re
def skip_level_0_prefix(model, state_dict):
    """Return True when the level-0 module-name prefix should be dropped while
    matching *state_dict* keys (bloom/opt models, unless keys already start
    with ``model.``)."""
    model = str(model)
    # Probe the repr for the architecture name: ": <Name>Model", ": <Name>Stack",
    # or a leading "<Name>Model".
    key = (re.search(r": (.*?)Model", model)
           or re.search(r": (.*?)Stack", model)
           or re.match(r"(.*?)Model", model))
    # if keys start with 'model.', don't skip level 0 prefix
    if state_dict is not None:
        for item in state_dict.keys():
            if re.match("^model[.]", item):
                return False
    if key is not None and key.group(1).lower() in ["bloom", "opt"]:
        return True
    return False
def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_dict=None):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.
        prefix (str): dotted state_dict key prefix accumulated so far.
        layer_id (int): running count of replaced layers, threaded through recursion.
        level_id (int): recursion depth; level 0 may have its prefix skipped.
        state_dict (dict, optional): checkpoint tensors to load into matching modules.
    Returns:
        Modified ``model`` and the updated ``layer_id``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            # Matched a policy: build the replacement and swap it in.
            replaced_module = policies[child.__class__][0](child,
                                                           policies[child.__class__][-1],
                                                           layer_id,
                                                           prefix=prefix + name,
                                                           state_dict=state_dict)
            setattr(model, name, replaced_module)
            if isinstance(model, PipelineModule):
                # Pipeline engines dispatch through forward_funcs, so the
                # stored callable must be updated too.
                assert hasattr(model, 'forward_funcs'),\
                    "we require pipe-module to have the list of fwd_functions"
                model.forward_funcs[model.fwd_map[name]] = replaced_module
            layer_id += 1
        else:
            checking_key = prefix + name + '.'
            if Loading.is_load_module(child) and state_dict is not None:
                if any(checking_key in item for item in state_dict):
                    Loading.load(
                        child,
                        state_dict,
                        checking_key,
                    )
                else:
                    # No checkpoint entries under this key: skip the subtree.
                    continue
            if len(child._buffers) != 0 and state_dict is not None:
                Loading.load_buffer(child, state_dict, checking_key)
            # Recurse; at level 0 some architectures (see skip_level_0_prefix)
            # omit the top-level module name from their checkpoint keys.
            _, layer_id = _replace_module(child,
                                          policies,
                                          prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \
                                          prefix + name + '.',
                                          layer_id=layer_id,
                                          level_id=level_id + 1,
                                          state_dict=state_dict)

    # Add the reset_cache func to the model, so that it can be called in the beginning of text-generation.
    model.reset_cache = transformer_inference.DeepSpeedTransformerInference.reset_cache
    return model, layer_id