# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import gc
import re

import torch
from torch import nn

import deepspeed.ops.transformer as transformer_inference
from deepspeed.accelerator import get_accelerator
from deepspeed.model_implementations.transformers.ds_bert import DeepSpeedBERTInference
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
from deepspeed.model_implementations.transformers.ds_llama2 import DeepSpeedLlama2Inference
from deepspeed.model_implementations.transformers.ds_megatron_gpt import DeepSpeedMegatronGPTInference
from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInference

from .layers import EmbeddingLayer, LinearLayer, Normalize, OPTEmbedding, RMSNormalize
- def load_model_with_checkpoint(r_module,
- sd,
- mp_replace,
- ckpt_type,
- ckpt_mp_size,
- weight_quantizer=None,
- rank=0,
- container=None):
- error_msgs = []
- def prefix_check():
- # if keys start with 'model.' or 'transformer.', don't skip level 0 prefix
- for key in sd[0].keys():
- # OPT models
- if re.match("^model[.]", key):
- return False
- # BLOOM models
- if re.match("^transformer[.]", key):
- return False
- return True
- skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix
- def transpose(data):
- with torch.no_grad():
- data = data.contiguous()
- data1 = data.transpose(-1, -2).reshape(-1)
- data.reshape(-1).copy_(data1)
- data1 = None
- return data.reshape(data.shape[-1], data.shape[-2])
- def load(module, prefix):
- args = (sd[0], prefix, {}, True, [], [], error_msgs)
- if hasattr(module, 'weight'):
- module.weight = mp_replace.copy(module.weight.data, sd[0][prefix + 'weight'])
- if prefix + 'bias' in sd[0].keys():
- if module.bias.data.is_meta:
- # meta tensor cannot be casted or copied to, so we need to replace it with a normal tensor here
- module.bias = torch.nn.parameter.Parameter(data=torch.empty_like(module.bias.data, device="cpu"),
- requires_grad=module.bias.data.requires_grad)
- module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
- args = None
- gc.collect()
- def load_transformer_layer(module, prefix):
- if ckpt_type == "tp":
- def load_parameters(module, prefix):
- for n, p in module.named_parameters():
- if prefix + n in sd[0] and len(n.split('.')) == 1:
- if type(sd[0][prefix + n]) is list:
- tmp_data, scale = sd[0][prefix + n]
- tmp_data = tmp_data
- scale = scale.to(get_accelerator().current_device_name())
- # set the quantizer number of groups using the checkpoint scale shape
- weight_quantizer.num_groups = scale.shape[0]
- else:
- tmp_data = sd[0][prefix + n].to(get_accelerator().current_device_name())
- scale = None
- src_shape = tmp_data.shape
- dst_shape = p.shape
- inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
- outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
- if (len(src_shape) == 2 and len(dst_shape) == 2):
- if (src_shape[inner_dim] == dst_shape[0] and src_shape[outer_dim] == dst_shape[1]):
- if tmp_data.dtype != torch.int8:
- p = weight_quantizer.quantize(
- transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
- else:
- p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
- p.scale = scale
- setattr(module, n, p)
- else:
- dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
- dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
- if src_shape[dim] > dst_shape[dim1]:
- weight_partition = torch.split(tmp_data, dst_shape[dim1], dim=dim)[rank].to(
- get_accelerator().current_device_name())
- assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank+1), \
- '''ERROR: We require the quantization scales for larger TP-size when loading INT8 checkpoint!\
- Please use the FP16 checkpoint to generate INT8 checkpoint with the sharding parameters!'''
- scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
- weight_quantizer.num_groups, -1).contiguous()
- else:
- assert tmp_data.dtype != torch.int8, \
- '''Merging of the checkpoints are not supported when using INT8 checkpoint! \
- Please use a as many GPUs as TP-size for the checkpoint'''
- all_data = [
- sd[j][prefix + n] if type(sd[j][prefix + n]) is list else sd[j][prefix + n].to(
- get_accelerator().current_device_name()) for j in range(len(sd))
- ]
- # Check if the weight tensor is for the QKV parameter
- if src_shape[1] == (3 * src_shape[0]) // ckpt_mp_size:
- qkv_size = src_shape[outer_dim] // 3
- src_split = [
- torch.split(src[0].data, qkv_size, dim=outer_dim) for src in all_data
- ]
- weight_partition = torch.cat([
- torch.cat([qkv_s[i] for qkv_s in src_split], axis=outer_dim)
- for i in range(len(src_split[0]))
- ],
- dim=dim)
- else:
- weight_partition = torch.cat([
- ad[0].to(get_accelerator().current_device_name())
- if type(ad) is list else ad for ad in all_data
- ],
- dim=dim)
- if tmp_data.dtype == torch.int8:
- scale = torch.cat(
- [ad[1].to(get_accelerator().current_device_name()) for ad in all_data],
- dim=dim)
- if tmp_data.dtype != torch.int8:
- weight_partition = weight_quantizer.quantize(
- transpose(weight_partition), \
- parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
- weight_quantizer.quantize(weight_partition)
- else:
- weight_partition = torch.nn.parameter.Parameter(weight_partition,
- requires_grad=False)
- weight_partition.scale = scale
- setattr(module, n, weight_partition)
- else:
- if src_shape[0] == dst_shape[0]:
- p.data.copy_(tmp_data)
- else:
- if src_shape[0] > dst_shape[0]:
- bias_split = torch.split(tmp_data, dst_shape[-1])[rank].to(
- get_accelerator().current_device_name()).contiguous()
- p.data.copy_(bias_split)
- else:
- # Check if the weight tensor is for the QKV parameter
- if src_shape[0] == (3 * r_module.config.hidden_size) // ckpt_mp_size:
- qkv_size = src_shape[0] // 3
- src_split = [
- torch.split(sd[j][prefix + n], qkv_size, dim=0) for j in range(len(sd))
- ]
- p.data.copy_(
- torch.cat([
- torch.cat([qkv_s[i] for qkv_s in src_split], axis=0)
- for i in range(len(src_split[0]))
- ],
- dim=0).to(get_accelerator().current_device_name()).contiguous())
- else:
- p.data.copy_(
- torch.cat([sd[j][prefix + n] for j in range(len(sd))],
- dim=0).to(get_accelerator().current_device_name()).contiguous())
- load_parameters(module, prefix)
- for n, child in module.named_children():
- load_parameters(child, prefix + n + '.')
- else:
- container.load_params(module, sd[0], weight_quantizer, mp_replace, prefix)
- try:
- import transformers
- OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
- if hasattr(transformers.models, "llama"):
- LlamaRMSNorm = transformers.models.llama.modeling_llama.LlamaRMSNorm
- else:
- LlamaRMSNorm = None
- except:
- OPTLearnedPositionalEmbedding = None
- try:
- from fairscale.nn.model_parallel.layers import (
- ColumnParallelLinear,
- ParallelEmbedding,
- RowParallelLinear,
- )
- except:
- ColumnParallelLinear = None
- ParallelEmbedding = None
- RowParallelLinear = None
- try:
- from llama.model import RMSNorm
- except:
- RMSNorm = None
- layer_policies = {
- nn.Linear: load,
- nn.Embedding: load,
- nn.LayerNorm: load,
- EmbeddingLayer: load,
- LinearLayer: load,
- Normalize: load,
- transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
- DeepSpeedBloomInference: load_transformer_layer,
- DeepSpeedGPTInference: load_transformer_layer,
- DeepSpeedBERTInference: load_transformer_layer,
- DeepSpeedMegatronGPTInference: load_transformer_layer,
- DeepSpeedOPTInference: load_transformer_layer,
- DeepSpeedLlama2Inference: load_transformer_layer,
- OPTLearnedPositionalEmbedding: load,
- OPTEmbedding: load,
- LlamaRMSNorm: load,
- RMSNormalize: load,
- ColumnParallelLinear: load,
- ParallelEmbedding: load,
- RowParallelLinear: load,
- RMSNorm: load
- }
- all_ds_ids = {}
- def load_module_recursive(module, prefix='', level=0):
- for name, child in module.named_children():
- if child.__class__ in layer_policies:
- checking_key = prefix + name + '.'
- if not any(checking_key in item for item in sd[0].keys()):
- if hasattr(child, 'weight') and \
- (hasattr(child.weight, 'ds_id') and \
- child.weight.ds_id in all_ds_ids):
- prefix1 = all_ds_ids[child.weight.ds_id]
- if child.__class__ is nn.Linear:
- child = LinearLayer(weight=all_ds_ids[child.weight.ds_id])
- setattr(module, name, child)
- continue
- child_params = list(child.parameters())
- if len(child_params) > 0 and (child_params[0].numel() == 0 or child_params[0].is_meta):
- if child.weight.is_meta:
- ds_shape = child.weight.shape
- else:
- ds_shape = child.weight.ds_shape
- if child.__class__ is nn.LayerNorm:
- child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
- setattr(module, name, child)
- elif child.__class__ in [nn.Linear, ColumnParallelLinear, RowParallelLinear]:
- child = LinearLayer(weight_shape=child.weight.shape, bias=child.bias)
- setattr(module, name, child)
- elif child.__class__ is OPTLearnedPositionalEmbedding:
- child = OPTEmbedding(weight_shape=ds_shape)
- setattr(module, name, child)
- elif child.__class__ in [LlamaRMSNorm, RMSNorm]:
- child = RMSNormalize(dim=ds_shape[-1],
- dtype=child.weight.dtype,
- eps=child.eps if hasattr(child, 'eps') else child.variance_epsilon)
- setattr(module, name, child)
- else:
- ds_id = None
- if hasattr(child.weight, 'ds_id'):
- ds_id = child.weight.ds_id
- child = EmbeddingLayer(weight_shape=ds_shape, dtype=child.weight.dtype)
- if ds_id is not None:
- all_ds_ids[ds_id] = child.weight
- setattr(module, name, child)
- layer_policies[child.__class__](child, prefix + name + '.')
- else:
- load_module_recursive(
- child,
- prefix if (level == 0 and ckpt_type == 'pp') and skip_level_0_prefix else \
- prefix + name + '.',
- level + 1)
- load_module_recursive(r_module)
- for sd_ in sd:
- del sd_
- sd = None
- gc.collect()