apply_delta.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. """
  2. Apply the delta weights on top of a base model.
  3. Usage:
  4. python3 apply_delta.py --base-model-path path/to/hf_llama/ --target-model-path path/to/gorilla-7b-hf-v0 --delta-path path/to/models--gorilla-llm--gorilla-7b-hf-delta-v0
  5. Thanks to LMSYS for the template of this code.
  6. """
  7. import argparse
  8. import gc
  9. import glob
  10. import json
  11. import os
  12. import shutil
  13. import tempfile
  14. from huggingface_hub import snapshot_download
  15. import torch
  16. from torch import nn
  17. from tqdm import tqdm
  18. from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
  19. GB = 1 << 30
  20. def split_files(model_path, tmp_path, split_size):
  21. if not os.path.exists(model_path):
  22. model_path = snapshot_download(repo_id=model_path)
  23. if not os.path.exists(tmp_path):
  24. os.makedirs(tmp_path)
  25. file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
  26. files = glob.glob(file_pattern)
  27. part = 0
  28. try:
  29. for file_path in tqdm(files):
  30. state_dict = torch.load(file_path)
  31. new_state_dict = {}
  32. current_size = 0
  33. for name, param in state_dict.items():
  34. param_size = param.numel() * param.element_size()
  35. if current_size + param_size > split_size:
  36. new_file_name = f"pytorch_model-{part}.bin"
  37. new_file_path = os.path.join(tmp_path, new_file_name)
  38. torch.save(new_state_dict, new_file_path)
  39. current_size = 0
  40. new_state_dict = None
  41. gc.collect()
  42. new_state_dict = {}
  43. part += 1
  44. new_state_dict[name] = param
  45. current_size += param_size
  46. new_file_name = f"pytorch_model-{part}.bin"
  47. new_file_path = os.path.join(tmp_path, new_file_name)
  48. torch.save(new_state_dict, new_file_path)
  49. new_state_dict = None
  50. gc.collect()
  51. new_state_dict = {}
  52. part += 1
  53. except Exception as e:
  54. print(f"An error occurred during split_files: {e}")
  55. shutil.rmtree(tmp_path)
  56. raise
  57. def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
  58. delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
  59. delta_config = AutoConfig.from_pretrained(delta_path)
  60. if os.path.exists(target_model_path):
  61. shutil.rmtree(target_model_path)
  62. os.makedirs(target_model_path)
  63. split_size = 4 * GB
  64. with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
  65. print(f"Split files for the base model to {tmp_base_path}")
  66. split_files(base_model_path, tmp_base_path, split_size)
  67. print(f"Split files for the delta weights to {tmp_delta_path}")
  68. split_files(delta_path, tmp_delta_path, split_size)
  69. base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
  70. base_files = glob.glob(base_pattern)
  71. delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
  72. delta_files = glob.glob(delta_pattern)
  73. delta_state_dict = torch.load(delta_files[0])
  74. print("Applying the delta")
  75. weight_map = {}
  76. total_size = 0
  77. for i, base_file in tqdm(enumerate(base_files)):
  78. state_dict = torch.load(base_file)
  79. file_name = f"pytorch_model-{i}.bin"
  80. for name, param in state_dict.items():
  81. if name not in delta_state_dict:
  82. for delta_file in delta_files:
  83. delta_state_dict = torch.load(delta_file)
  84. gc.collect()
  85. if name in delta_state_dict:
  86. break
  87. state_dict[name] += delta_state_dict[name]
  88. weight_map[name] = file_name
  89. total_size += param.numel() * param.element_size()
  90. gc.collect()
  91. torch.save(state_dict, os.path.join(target_model_path, file_name))
  92. with open(
  93. os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
  94. ) as f:
  95. json.dump(
  96. {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
  97. )
  98. print(f"Saving the target model to {target_model_path}")
  99. delta_tokenizer.save_pretrained(target_model_path)
  100. delta_config.save_pretrained(target_model_path)
  101. def apply_delta(base_model_path, target_model_path, delta_path):
  102. print(f"Loading the delta weights from {delta_path}")
  103. delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
  104. delta = AutoModelForCausalLM.from_pretrained(
  105. delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
  106. )
  107. print(f"Loading the base model from {base_model_path}")
  108. base = AutoModelForCausalLM.from_pretrained(
  109. base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
  110. )
  111. print("Applying the delta")
  112. for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
  113. assert name in delta.state_dict()
  114. param.data += delta.state_dict()[name]
  115. print(f"Saving the target model to {target_model_path}")
  116. base.save_pretrained(target_model_path)
  117. delta_tokenizer.save_pretrained(target_model_path)
  118. if __name__ == "__main__":
  119. parser = argparse.ArgumentParser()
  120. parser.add_argument("--base-model-path", type=str, required=True)
  121. parser.add_argument("--target-model-path", type=str, required=True)
  122. parser.add_argument("--delta-path", type=str, required=True)
  123. parser.add_argument(
  124. "--low-cpu-mem",
  125. action="store_true",
  126. help="Lower the cpu memory usage. This will split large files and use "
  127. "disk as swap to reduce the memory usage below 10GB.",
  128. )
  129. args = parser.parse_args()
  130. if args.low_cpu_mem:
  131. apply_delta_low_cpu_mem(
  132. args.base_model_path, args.target_model_path, args.delta_path
  133. )
  134. else:
  135. apply_delta(args.base_model_path, args.target_model_path, args.delta_path)