# coding=utf-8
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
#
# This code is inspired by Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel


def change_name(name: str, old_index: int, new_index: int) -> str:
    # rename the layer index embedded in a parameter name,
    # e.g. change_name("model.layers.3.mlp.down_proj.weight", 3, 4)
    #      -> "model.layers.4.mlp.down_proj.weight"
    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))


def block_expansion(
    model_name_or_path: str,
    output_dir: str,
    num_expand: int,
    shard_size: str = "2GB",
    save_safetensors: bool = True,
):
    r"""
    Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
    Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
    """
    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
    num_layers = getattr(config, "num_hidden_layers")
    setattr(config, "num_hidden_layers", num_layers + num_expand)
    config.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.save_pretrained(output_dir)

    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
    if save_safetensors:
        setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights

    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        torch_dtype="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    state_dict = model.state_dict()

    if num_layers % num_expand != 0:
        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))

    split = num_layers // num_expand
    layer_cnt = 0
    output_state_dict = OrderedDict()
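    # walk through the original layers in order; after every `split` of them,
    # append one duplicated "expanded" layer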
    for i in range(num_layers):
        for key, value in state_dict.items():
            if ".{:d}.".format(i) in key:
                output_state_dict[change_name(key, i, layer_cnt)] = value

        print("Add layer {} copied from layer {}".format(layer_cnt, i))
        layer_cnt += 1
        if (i + 1) % split == 0:
            for key, value in state_dict.items():
                if ".{:d}.".format(i) in key:
                    # zeroing the attention and MLP output projections makes both
                    # residual branches output zero, so the duplicated layer is an
                    # identity function at initialization (as in LLaMA-Pro)
                    if "down_proj" in key or "o_proj" in key:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                    else:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
            layer_cnt += 1

    # copy the remaining non-layer weights unchanged (e.g. embeddings, final norm, lm_head)
    for key, value in state_dict.items():
        if key not in output_state_dict:
            output_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    # note: `shard_checkpoint` was deprecated and later dropped from transformers, so
    # this script expects a version that still exports it from `modeling_utils`
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))
- print("- Fine-tune this model with:")
- print("model_name_or_path: {}".format(output_dir))
- print("finetuning_type: freeze")
- print("freeze_trainable_layers: {}".format(num_expand))
- print("use_llama_pro: true")


if __name__ == "__main__":
    fire.Fire(block_expansion)