llamafy_baichuan2.py
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections import OrderedDict
from typing import Any, Dict

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
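    # Merge every PyTorch shard (*.bin) of the Baichuan2 checkpoint into a single
    # state dict on CPU; this assumes the full model fits in host memory.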
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            baichuan2_state_dict.update(shard_weight)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
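        # Baichuan2 fuses the attention projections into a single packed linear
        # layer "W_pack" of shape (3 * hidden_size, hidden_size), with the Q, K
        # and V weights stacked along dim 0; LLaMA expects three separate
        # projections, so slice the packed matrix into equal thirds.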
  39. if "W_pack" in key:
  40. proj_size = value.size(0) // 3
  41. llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
  42. llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
  43. llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
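        # Baichuan2's NormHead L2-normalizes the lm_head weight rows at inference
        # time; baking that normalization into the stored weight lets a plain
        # LLaMA head reproduce the same logits.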
        elif "lm_head" in key:
            llama2_state_dict[key] = torch.nn.functional.normalize(value)
        else:
            llama2_state_dict[key] = value
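
    # Shard the converted state dict using the helper from transformers. Note:
    # shard_checkpoint was removed from transformers.modeling_utils in newer
    # releases, so this script expects an older version that still ships it.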
    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)
    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))


def save_config(input_dir: str, output_dir: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        llama2_config_dict: Dict[str, Any] = json.load(f)
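
    # Rewrite the config so that transformers loads the converted checkpoint with
    # its native LLaMA classes: drop the remote-code hooks ("auto_map" and the
    # custom tokenizer class) and declare the llama architecture and model type.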
    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict.pop("auto_map", None)
    llama2_config_dict.pop("tokenizer_class", None)
    llama2_config_dict["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)

    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_baichuan2(
    input_dir: str,
    output_dir: str,
    shard_size: str = "2GB",
    save_safetensors: bool = True,
):
    r"""
    Converts the Baichuan2-7B model to the same format as LLaMA2-7B.
    Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
    Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
    """
    try:
        os.makedirs(output_dir, exist_ok=False)
    except FileExistsError as e:
        raise RuntimeError("Output dir already exists: {}".format(output_dir)) from e

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    fire.Fire(llamafy_baichuan2)
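

# A minimal sanity check for a converted checkpoint (a sketch, not part of the
# script). This script converts only the weights and config, so copy the
# tokenizer files into the output directory separately before using the model
# end to end:
#
#   from transformers import LlamaForCausalLM
#   model = LlamaForCausalLM.from_pretrained("output", torch_dtype="auto")
#   assert model.config.model_type == "llama"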