base_agent.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import os
  2. import json
  3. import time
  4. from threading import Thread
  5. import threading
  6. from ..utils.logger import AgentLogger
  7. from ..utils.chat_template import Query
  8. import importlib
  9. from aios.hooks.request import send_request
  10. class BaseAgent:
  11. def __init__(self, agent_name, task_input, log_mode: str):
  12. # super().__init__()
  13. self.agent_name = agent_name
  14. self.config = self.load_config()
  15. self.tool_names = self.config["tools"]
  16. self.plan_max_fail_times = 3
  17. self.tool_call_max_fail_times = 3
  18. # self.agent_process_factory = agent_process_factory
  19. self.tool_list = dict()
  20. self.tools = []
  21. self.tool_info = (
  22. []
  23. ) # simplified information of the tool: {"name": "xxx", "description": "xxx"}
  24. self.load_tools(self.tool_names)
  25. self.start_time = None
  26. self.end_time = None
  27. self.request_waiting_times: list = []
  28. self.request_turnaround_times: list = []
  29. self.task_input = task_input
  30. self.messages = []
  31. self.workflow_mode = "manual" # (mannual, automatic)
  32. self.rounds = 0
  33. self.log_mode = log_mode
  34. self.logger = self.setup_logger()
  35. self.set_status("active")
  36. self.set_created_time(time.time())
  37. def run(self):
  38. """Execute each step to finish the task."""
  39. # self.set_aid(threading.get_ident())
  40. self.logger.log(
  41. f"{self.agent_name} starts running. Agent ID is {self.get_aid()}\n",
  42. level="info",
  43. )
  44. # can be customization
  45. def build_system_instruction(self):
  46. pass
  47. def check_workflow(self, message):
  48. try:
  49. # print(f"Workflow message: {message}")
  50. workflow = json.loads(message)
  51. if not isinstance(workflow, list):
  52. return None
  53. for step in workflow:
  54. if "message" not in step or "tool_use" not in step:
  55. return None
  56. return workflow
  57. except json.JSONDecodeError:
  58. return None
  59. def automatic_workflow(self):
  60. for i in range(self.plan_max_fail_times):
  61. response, start_times, end_times, waiting_times, turnaround_times = send_request(
  62. agent_name = self.agent_name,
  63. query=Query(
  64. messages=self.messages, tools=None, message_return_type="json"
  65. )
  66. )
  67. if self.rounds == 0:
  68. self.set_start_time(start_times[0])
  69. self.request_waiting_times.extend(waiting_times)
  70. self.request_turnaround_times.extend(turnaround_times)
  71. workflow = self.check_workflow(response.response_message)
  72. self.rounds += 1
  73. if workflow:
  74. return workflow
  75. else:
  76. self.messages.append(
  77. {
  78. "role": "assistant",
  79. "content": f"Fail {i+1} times to generate a valid plan. I need to regenerate a plan",
  80. }
  81. )
  82. return None
  83. def manual_workflow(self):
  84. pass
  85. def check_path(self, tool_calls):
  86. script_path = os.path.abspath(__file__)
  87. save_dir = os.path.join(
  88. os.path.dirname(script_path), "output"
  89. ) # modify the customized output path for saving outputs
  90. if not os.path.exists(save_dir):
  91. os.makedirs(save_dir)
  92. for tool_call in tool_calls:
  93. try:
  94. for k in tool_call["parameters"]:
  95. if "path" in k:
  96. path = tool_call["parameters"][k]
  97. if not path.startswith(save_dir):
  98. tool_call["parameters"][k] = os.path.join(
  99. save_dir, os.path.basename(path)
  100. )
  101. except Exception:
  102. continue
  103. return tool_calls
  104. def snake_to_camel(self, snake_str):
  105. components = snake_str.split("_")
  106. return "".join(x.title() for x in components)
  107. def load_tools(self, tool_names):
  108. if tool_names == "None":
  109. return
  110. for tool_name in tool_names:
  111. org, name = tool_name.split("/")
  112. module_name = ".".join(["pyopenagi", "tools", org, name])
  113. class_name = self.snake_to_camel(name)
  114. tool_module = importlib.import_module(module_name)
  115. tool_class = getattr(tool_module, class_name)
  116. self.tool_list[name] = tool_class()
  117. tool_format = tool_class().get_tool_call_format()
  118. self.tools.append(tool_format)
  119. self.tool_info.append(
  120. {
  121. "name": tool_format["function"]["name"],
  122. "description": tool_format["function"]["description"],
  123. }
  124. )
  125. def pre_select_tools(self, tool_names):
  126. pre_selected_tools = []
  127. for tool_name in tool_names:
  128. for tool in self.tools:
  129. if tool["function"]["name"] == tool_name:
  130. pre_selected_tools.append(tool)
  131. break
  132. return pre_selected_tools
  133. def setup_logger(self):
  134. logger = AgentLogger(self.agent_name, self.log_mode)
  135. return logger
  136. def load_config(self):
  137. script_path = os.path.abspath(__file__)
  138. script_dir = os.path.dirname(script_path)
  139. config_file = os.path.join(script_dir, self.agent_name, "config.json")
  140. with open(config_file, "r") as f:
  141. config = json.load(f)
  142. return config
  143. def set_aid(self, aid):
  144. self.aid = aid
  145. def get_aid(self):
  146. return self.aid
  147. def get_agent_name(self):
  148. return self.agent_name
  149. def set_status(self, status):
  150. """
  151. Status type: Waiting, Running, Done, Inactive
  152. """
  153. self.status = status
  154. def get_status(self):
  155. return self.status
  156. def set_created_time(self, time):
  157. self.created_time = time
  158. def get_created_time(self):
  159. return self.created_time
  160. def set_start_time(self, time):
  161. self.start_time = time
  162. def get_start_time(self):
  163. return self.start_time
  164. def set_end_time(self, time):
  165. self.end_time = time
  166. def get_end_time(self):
  167. return self.end_time