llama_pro.py

# coding=utf-8
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
#
# This code is inspired by Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)
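
# NOTE: `shard_checkpoint` has been deprecated and removed in recent transformers releases;
# this script assumes a transformers version that still exports it.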

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel
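
# Rename a layer-indexed weight key, e.g. change_name("model.layers.3.mlp.down_proj.weight", 3, 5)
# returns "model.layers.5.mlp.down_proj.weight".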
def change_name(name: str, old_index: int, new_index: int) -> str:
    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))

def block_expansion(
    model_name_or_path: str,
    output_dir: str,
    num_expand: int,
    shard_size: str = "2GB",
    save_safetensors: bool = True,
):
    r"""
    Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
    Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
    """
    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
    num_layers = getattr(config, "num_hidden_layers")
    setattr(config, "num_hidden_layers", num_layers + num_expand)
    config.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.save_pretrained(output_dir)
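
    # Reload the original (unexpanded) config and pretrained weights; safetensors cannot serialize
    # tied (shared) tensors, so the word embeddings are untied when saving in that format.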
    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
    if save_safetensors:
        setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights

    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        torch_dtype="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    state_dict = model.state_dict()

    if num_layers % num_expand != 0:
        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))

    split = num_layers // num_expand
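
    # Interleaved expansion: after every `split` original layers, append one extra layer copied from the
    # preceding layer but with its output projections (o_proj, down_proj) zeroed, so the new block
    # initially acts as an identity mapping (the LLaMA Pro initialization).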
    layer_cnt = 0
    output_state_dict = OrderedDict()
    for i in range(num_layers):
        for key, value in state_dict.items():
            if ".{:d}.".format(i) in key:
                output_state_dict[change_name(key, i, layer_cnt)] = value

        print("Add layer {} copied from layer {}".format(layer_cnt, i))
        layer_cnt += 1
        if (i + 1) % split == 0:
            for key, value in state_dict.items():
                if ".{:d}.".format(i) in key:
                    if "down_proj" in key or "o_proj" in key:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                    else:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
            layer_cnt += 1
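
    # Copy everything that is not a decoder layer (e.g. embeddings, final norm, lm_head) unchanged.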
    for key, value in state_dict.items():
        if key not in output_state_dict:
            output_state_dict[key] = value
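
    # Shard the expanded state dict and write it as safetensors or PyTorch .bin files,
    # adding an index file when the checkpoint spans multiple shards.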
    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)

        print("Model weights saved in {}".format(output_dir))

    print("- Fine-tune this model with:")
    print("model_name_or_path: {}".format(output_dir))
    print("finetuning_type: freeze")
    print("freeze_trainable_layers: {}".format(num_expand))
    print("use_llama_pro: true")

if __name__ == "__main__":
    fire.Fire(block_expansion)
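
# Worked example (a sketch of the logic above, not script output): expanding a 32-layer model with
# --num_expand 8 gives split = 4, so one zero-initialized block is inserted after every 4 original
# layers, for 40 layers in total; the printed LLaMA-Factory settings are meant to train only those
# 8 new blocks.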