# claude_llm.py — Claude (Anthropic) backend for the LLM abstraction layer.
  1. import re
  2. import json
  3. import time
  4. import anthropic
  5. from typing import List, Dict, Any
  6. from .base_llm import BaseLLM
  7. from pyopenagi.utils.chat_template import Response
  8. class ClaudeLLM(BaseLLM):
  9. """
  10. ClaudeLLM class for interacting with Anthropic's Claude models.
  11. This class provides methods for processing queries using Claude models,
  12. including handling of tool calls and message formatting.
  13. Attributes:
  14. model (anthropic.Anthropic): The Anthropic client for API calls.
  15. tokenizer (None): Placeholder for tokenizer, not used in this implementation.
  16. """
  17. def __init__(self, llm_name: str,
  18. max_gpu_memory: Dict[int, str] = None,
  19. eval_device: str = None,
  20. max_new_tokens: int = 256,
  21. log_mode: str = "console"):
  22. """
  23. Initialize the ClaudeLLM instance.
  24. Args:
  25. llm_name (str): Name of the Claude model to use.
  26. max_gpu_memory (Dict[int, str], optional): GPU memory configuration.
  27. eval_device (str, optional): Device for evaluation.
  28. max_new_tokens (int, optional): Maximum number of new tokens to generate.
  29. log_mode (str, optional): Logging mode, defaults to "console".
  30. """
  31. super().__init__(llm_name,
  32. max_gpu_memory=max_gpu_memory,
  33. eval_device=eval_device,
  34. max_new_tokens=max_new_tokens,
  35. log_mode=log_mode)
  36. def load_llm_and_tokenizer(self) -> None:
  37. """
  38. Load the Anthropic client for API calls.
  39. """
  40. self.model = anthropic.Anthropic()
  41. self.tokenizer = None
  42. def process(self, agent_request: Any, temperature: float = 0.0) -> None:
  43. """
  44. Process a query using the Claude model.
  45. Args:
  46. agent_request (Any): The agent process containing the query and tools.
  47. temperature (float, optional): Sampling temperature for generation.
  48. Raises:
  49. AssertionError: If the model name doesn't contain 'claude'.
  50. anthropic.APIError: If there's an error with the Anthropic API call.
  51. Exception: For any other unexpected errors.
  52. """
  53. assert re.search(r'claude', self.model_name, re.IGNORECASE), "Model name must contain 'claude'"
  54. agent_request.set_status("executing")
  55. agent_request.set_start_time(time.time())
  56. messages = agent_request.query.messages
  57. tools = agent_request.query.tools
  58. self.logger.log(f"{messages}", level="info")
  59. self.logger.log(f"{agent_request.agent_name} is switched to executing.", level="executing")
  60. if tools:
  61. messages = self.tool_calling_input_format(messages, tools)
  62. anthropic_messages = self._convert_to_anthropic_messages(messages)
  63. self.logger.log(f"{anthropic_messages}", level="info")
  64. try:
  65. response = self.model.messages.create(
  66. model=self.model_name,
  67. messages=anthropic_messages,
  68. max_tokens=self.max_new_tokens,
  69. temperature=temperature
  70. )
  71. response_message = response.content[0].text
  72. self.logger.log(f"API Response: {response_message}", level="info")
  73. tool_calls = self.parse_tool_calls(response_message) if tools else None
  74. response = Response(
  75. response_message=response_message,
  76. tool_calls=tool_calls
  77. )
  78. # agent_request.set_response(
  79. # Response(
  80. # response_message=response_message,
  81. # tool_calls=tool_calls
  82. # )
  83. # )
  84. except anthropic.APIError as e:
  85. error_message = f"Anthropic API error: {str(e)}"
  86. self.logger.log(error_message, level="warning")
  87. response = Response(
  88. response_message=f"Error: {str(e)}",
  89. tool_calls=None
  90. )
  91. # agent_request.set_response(
  92. # Response(
  93. # response_message=f"Error: {str(e)}",
  94. # tool_calls=None
  95. # )
  96. # )
  97. except Exception as e:
  98. error_message = f"Unexpected error: {str(e)}"
  99. self.logger.log(error_message, level="warning")
  100. # agent_request.set_response(
  101. # Response(
  102. # response_message=f"Unexpected error: {str(e)}",
  103. # tool_calls=None
  104. # )
  105. # )
  106. response = Response(
  107. response_message=f"Unexpected error: {str(e)}",
  108. tool_calls=None
  109. )
  110. return response
  111. # agent_request.set_status("done")
  112. # agent_request.set_end_time(time.time())
  113. def _convert_to_anthropic_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
  114. """
  115. Convert messages to the format expected by the Anthropic API.
  116. Args:
  117. messages (List[Dict[str, str]]): Original messages.
  118. Returns:
  119. List[Dict[str, str]]: Converted messages for Anthropic API.
  120. """
  121. anthropic_messages = []
  122. for message in messages:
  123. if message['role'] == 'system':
  124. anthropic_messages.append({"role": "user", "content": f"System: {message['content']}"})
  125. anthropic_messages.append({"role": "assistant", "content": "Understood. I will follow these instructions."})
  126. else:
  127. anthropic_messages.append({
  128. "role": "user" if message['role'] == 'user' else "assistant",
  129. "content": message['content']
  130. })
  131. return anthropic_messages
  132. def tool_calling_output_format(self, tool_calling_messages: str) -> List[Dict[str, Any]]:
  133. """
  134. Parse the tool calling output from the model's response.
  135. Args:
  136. tool_calling_messages (str): The model's response containing tool calls.
  137. Returns:
  138. List[Dict[str, Any]]: Parsed tool calls, or None if parsing fails.
  139. """
  140. try:
  141. json_content = json.loads(tool_calling_messages)
  142. if isinstance(json_content, list) and len(json_content) > 0 and 'name' in json_content[0]:
  143. return json_content
  144. except json.JSONDecodeError:
  145. pass
  146. return super().tool_calling_output_format(tool_calling_messages)