# chatglm微调工具.py — ChatGLM fine-tuning helper plugin (dataset generation + p-tuning launch)
  1. from toolbox import CatchException, update_ui, promote_file_to_downloadzone
  2. from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
  3. import datetime, json
  4. def fetch_items(list_of_items, batch_size):
  5. for i in range(0, len(list_of_items), batch_size):
  6. yield list_of_items[i:i + batch_size]
  7. def string_to_options(arguments):
  8. import argparse
  9. import shlex
  10. # Create an argparse.ArgumentParser instance
  11. parser = argparse.ArgumentParser()
  12. # Add command-line arguments
  13. parser.add_argument("--llm_to_learn", type=str, help="LLM model to learn", default="gpt-3.5-turbo")
  14. parser.add_argument("--prompt_prefix", type=str, help="Prompt prefix", default='')
  15. parser.add_argument("--system_prompt", type=str, help="System prompt", default='')
  16. parser.add_argument("--batch", type=int, help="System prompt", default=50)
  17. parser.add_argument("--pre_seq_len", type=int, help="pre_seq_len", default=50)
  18. parser.add_argument("--learning_rate", type=float, help="learning_rate", default=2e-2)
  19. parser.add_argument("--num_gpus", type=int, help="num_gpus", default=1)
  20. parser.add_argument("--json_dataset", type=str, help="json_dataset", default="")
  21. parser.add_argument("--ptuning_directory", type=str, help="ptuning_directory", default="")
  22. # Parse the arguments
  23. args = parser.parse_args(shlex.split(arguments))
  24. return args
@CatchException
def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
    """Generate a fine-tuning dataset: run every "content" line of a JSONL file
    through the chosen LLM and append {content, summary} pairs to a sidecar file.

    txt             path to a JSONL input file; each line must have a "content" key
    llm_kwargs      gpt model parameters (temperature, top_p, ...); passed through
    plugin_kwargs   plugin parameters; "advanced_arg" carries the option string
    chatbot         chat display handle, used to show progress to the user
    history         chat history (cleared here)
    system_prompt   silent system prompt for gpt (unused; options provide their own)
    web_port        port the application is running on (unused here)
    """
    history = []  # clear history to avoid context-length overflow
    chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
    # Treat an empty advanced_arg the same as a missing one.
    if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
    args = plugin_kwargs.get("advanced_arg", None)
    if args is None:
        # No option string supplied: tell the user and bail out.
        chatbot.append(("没给定指令", "退出"))
        yield from update_ui(chatbot=chatbot, history=history); return
    else:
        arguments = string_to_options(arguments=args)

    # Collect the "content" field of every JSONL line into a flat list.
    dat = []
    with open(txt, 'r', encoding='utf8') as f:
        for line in f.readlines():
            json_dat = json.loads(line)
            dat.append(json_dat["content"])

    # Route requests to the model we want to learn from, in batches.
    llm_kwargs['llm_model'] = arguments.llm_to_learn
    for batch in fetch_items(dat, arguments.batch):
        res = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
            inputs_array=[f"{arguments.prompt_prefix}\n\n{b}" for b in (batch)],
            inputs_show_user_array=[f"Show Nothing" for _ in (batch)],
            llm_kwargs=llm_kwargs,
            chatbot=chatbot,
            history_array=[[] for _ in (batch)],
            sys_prompt_array=[arguments.system_prompt for _ in (batch)],
            max_workers=10  # maximum parallel requests OpenAI tolerates
        )
        # Append results after each batch so partial progress survives a crash.
        # NOTE(review): res presumably interleaves inputs and responses, so
        # res[1::2] picks the responses — confirm against the helper's contract.
        with open(txt + '.generated.json', 'a+', encoding='utf8') as f:
            for b, r in zip(batch, res[1::2]):
                f.write(json.dumps({"content": b, "summary": r}, ensure_ascii=False) + '\n')

    # Expose the accumulated JSONL output in the download zone.
    promote_file_to_downloadzone(txt + '.generated.json', rename_file='generated.json', chatbot=chatbot)
    return
@CatchException
def 启动微调(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
    """Launch a ChatGLM2-6B p-tuning run via ``torchrun`` in a subprocess.

    txt             user input text (unused here; options come from plugin_kwargs)
    llm_kwargs      gpt model parameters (unused here)
    plugin_kwargs   plugin parameters; "advanced_arg" carries the option string
    chatbot         chat display handle, used to show progress to the user
    history         chat history (cleared here)
    system_prompt   silent system prompt for gpt (unused here)
    web_port        port the application is running on (unused here)
    """
    import subprocess
    history = []  # clear history to avoid context-length overflow
    chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
    # Treat an empty advanced_arg the same as a missing one.
    if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
    args = plugin_kwargs.get("advanced_arg", None)
    if args is None:
        # No option string supplied: tell the user and bail out.
        chatbot.append(("没给定指令", "退出"))
        yield from update_ui(chatbot=chatbot, history=history); return
    else:
        arguments = string_to_options(arguments=args)

    pre_seq_len = arguments.pre_seq_len            # e.g. 128
    learning_rate = arguments.learning_rate        # e.g. 2e-2
    num_gpus = arguments.num_gpus                  # e.g. 1
    json_dataset = arguments.json_dataset          # e.g. 't_code.json'
    ptuning_directory = arguments.ptuning_directory  # e.g. '/home/hmp/ChatGLM2-6B/ptuning'

    # Build the torchrun command line for the ChatGLM2 ptuning main.py script.
    # SECURITY NOTE(review): shell=True with user-supplied option values
    # (json_dataset, ptuning_directory, ...) allows shell injection; consider
    # subprocess.run([...], shell=False) with an argument list.
    command = f"torchrun --standalone --nnodes=1 --nproc-per-node={num_gpus} main.py \
    --do_train \
    --train_file AdvertiseGen/{json_dataset} \
    --validation_file AdvertiseGen/{json_dataset} \
    --preprocessing_num_workers 20 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm2-6b \
    --output_dir output/clothgen-chatglm2-6b-pt-{pre_seq_len}-{learning_rate} \
    --overwrite_output_dir \
    --max_source_length 256 \
    --max_target_length 256 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 100 \
    --logging_steps 10 \
    --save_steps 20 \
    --learning_rate {learning_rate} \
    --pre_seq_len {pre_seq_len} \
    --quantization_bit 4"

    # Run inside the ptuning directory; wait up to 24h, then kill the job.
    process = subprocess.Popen(command, shell=True, cwd=ptuning_directory)
    try:
        process.communicate(timeout=3600 * 24)
    except subprocess.TimeoutExpired:
        process.kill()
    return