inference_local.py

import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from openfunctions_utils import strip_function_calls, parse_function_call


def get_prompt(user_query: str, functions: list = []) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list): A list of functions to include in the prompt.

    Returns:
    - str: The formatted conversation prompt.
    """
    system = "You are an AI programming assistant, utilizing the Gorilla LLM model, developed by Gorilla LLM, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer."
    if len(functions) == 0:
        return f"{system}\n### Instruction: <<question>> {user_query}\n### Response: "
    functions_string = json.dumps(functions)
    return f"{system}\n### Instruction: <<function>>{functions_string}\n<<question>>{user_query}\n### Response: "


def format_response(response: str):
    """
    Formats the response from the OpenFunctions model.

    Parameters:
    - response (str): The response generated by the LLM.

    Returns:
    - str: The formatted response.
    - dict: The function call(s) extracted from the response.
    """
    function_call_dicts = None
    try:
        response = strip_function_calls(response)
        # Parallel function calls returned as a str, list[dict]
        if len(response) > 1:
            function_call_dicts = []
            for function_call in response:
                function_call_dicts.append(parse_function_call(function_call))
            response = ", ".join(response)
        # Single function call returned as a str, dict
        else:
            function_call_dicts = parse_function_call(response[0])
            response = response[0]
    except Exception as e:
        # Just faithfully return the generated response str to the user
        pass
    return response, function_call_dicts
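
# Illustrative note (not part of the original file): for a single call, the raw model
# output is stripped down to something like "get_current_weather(location='Boston, MA')",
# and parse_function_call() turns it into an OpenAI-style dict, e.g.
# {"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}.
# The exact dict layout is defined by openfunctions_utils and may differ; treat this
# as a sketch of the expected shape, not a guaranteed schema.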

# Device setup
device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and tokenizer setup
model_id: str = "gorilla-llm/gorilla-openfunctions-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)

# Move model to device
model.to(device)

# Pipeline setup
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

# Example usage 1
# This should return 2 function calls with the right arguments
query_1: str = "What's the weather like in the two cities of Boston and San Francisco?"
functions_1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]

# Example usage 2
# This should return an error since the function can't help with the prompt
query_2: str = "What is the freezing point of water at a pressure of 10 kPa?"
functions_2 = [
    {
        "name": "thermodynamics.calculate_boiling_point",
        "description": "Calculate the boiling point of a given substance at a specific pressure.",
        "parameters": {
            "type": "object",
            "properties": {
                "substance": {"type": "string", "description": "The substance for which to calculate the boiling point."},
                "pressure": {"type": "number", "description": "The pressure at which to calculate the boiling point."},
                "unit": {"type": "string", "description": "The unit of the pressure. Default is 'kPa'."},
            },
            "required": ["substance", "pressure"],
        },
    }
]

# Generate prompt and obtain model output
prompt_1 = get_prompt(query_1, functions=functions_1)
output_1 = pipe(prompt_1)
fn_call_string, function_call_dict = format_response(output_1[0]['generated_text'])
print("--------------------")
print(f"Function call string(s) 1: {fn_call_string}")
print("--------------------")
print(f"OpenAI compatible `function_call`: {function_call_dict}")
print("--------------------")