server.py

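"""FastAPI server for the AIOS kernel.

Exposes HTTP endpoints for (re)configuring the LLM kernel, submitting agents,
awaiting agent execution, parsing agent queries, and listing the agents
available from the agent manager.
"""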
from collections import OrderedDict
from fastapi import Depends, FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from aios.hooks.llm import useFIFOScheduler, useFactory, useKernel
from aios.hooks.types.llm import AgentSubmitDeclaration, LLMParams
from aios.hooks.parser import string
from aios.core.schema import CoreSchema
from aios.hooks.types.parser import ParserQuery
from aios.utils.utils import (
    parse_global_args,
)
from pyopenagi.manager.manager import AgentManager
from aios.utils.state import useGlobalState
from dotenv import load_dotenv
import atexit
import json
import sys

load_dotenv()

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Per-resource global state hooks; each call returns (getter, setter, set_callback).
getLLMState, setLLMState, setLLMCallback = useGlobalState()
getFactory, setFactory, setFactoryCallback = useGlobalState()
getManager, setManager, setManagerCallback = useGlobalState()

setManager(AgentManager("https://my.aios.foundation"))

parser = parse_global_args()
args = parser.parse_args()

# Prefer aios_config.json if it exists; otherwise fall back to command-line args.
try:
    with open("aios_config.json", "r") as f:
        aios_config = json.load(f)

    # print to stderr
    print("Loaded aios_config.json, ignoring args", file=sys.stderr)

    llm_cores = aios_config["llm_cores"][0]

    # only check aios_config.json
    setLLMState(
        useKernel(
            llm_name=llm_cores.get("llm_name"),
            max_gpu_memory=llm_cores.get("max_gpu_memory"),
            eval_device=llm_cores.get("eval_device"),
            max_new_tokens=llm_cores.get("max_new_tokens"),
            log_mode="console",
            use_backend=llm_cores.get("use_backend"),
        )
    )
except FileNotFoundError:
    aios_config = {}

    # only check args
    setLLMState(
        useKernel(
            llm_name=args.llm_name,
            max_gpu_memory=args.max_gpu_memory,
            eval_device=args.eval_device,
            max_new_tokens=args.max_new_tokens,
            log_mode=args.log_mode,
            use_backend=args.use_backend,
        )
    )

startScheduler, stopScheduler = useFIFOScheduler(
    llm=getLLMState(), log_mode=args.log_mode, get_queue_message=None
)

submitAgent, awaitAgentExecution = useFactory(log_mode=args.log_mode, max_workers=500)

setFactory({"submit": submitAgent, "execute": awaitAgentExecution})

startScheduler()


@app.post("/set_kernel")
async def set_kernel(req: LLMParams):
    # Unpack the request body (a Pydantic v2 model) and swap the rebuilt
    # kernel into global state.
    setLLMState(useKernel(**req.model_dump()))

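# Example /set_kernel request (a sketch: host/port and the model name are
# illustrative, and the exact LLMParams fields may differ from the useKernel()
# arguments they mirror here):
#   curl -X POST http://localhost:8000/set_kernel \
#     -H "Content-Type: application/json" \
#     -d '{"llm_name": "example-model", "max_new_tokens": 256, "log_mode": "console"}'
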
@app.post("/add_agent")
async def add_agent(
    req: AgentSubmitDeclaration,
    factory: dict = Depends(getFactory),
):
    try:
        submit_agent = factory.get("submit")
        process_id = submit_agent(agent_name=req.agent_name, task_input=req.task_input)
        return {"success": True, "agent": req.agent_name, "pid": process_id}
    except Exception as e:
        return {"success": False, "exception": f"{e}"}

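# Example /add_agent request (a sketch; the agent name and task below are
# illustrative, not agents shipped with this repository):
#   curl -X POST http://localhost:8000/add_agent \
#     -H "Content-Type: application/json" \
#     -d '{"agent_name": "example/demo_agent", "task_input": "Summarize this text"}'
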
@app.get("/execute_agent")
async def execute_agent(
    pid: int = Query(..., description="The process ID"),
    factory: dict = Depends(getFactory),
):
    try:
        response = factory.get("execute")(pid)
        return {"success": True, "response": response}
    except Exception as e:
        print("Got an exception while executing agent: ", e)
        return {"success": False, "exception": f"{e}"}

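# Example /execute_agent request (a sketch; use the pid returned by /add_agent):
#   curl "http://localhost:8000/execute_agent?pid=0"
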
@app.post("/agent_parser")
async def parse_query(req: ParserQuery):
    # Declare the fields the parser is expected to extract from the query.
    parser_schema = CoreSchema()
    parser_schema.add_field("agent_name", string, "name of agent").add_field(
        "phrase", string, "agent instruction"
    )

@app.get("/get_all_agents")
async def get_all_agents():
    manager: AgentManager = getManager()

    def transform_string(input_string: str):
        # Strip the final "/"-separated segment from the agent identifier.
        return "/".join(input_string.split("/")[:-1])

    agents = manager.list_available_agents()
    print(agents)

    # Walk the list in reverse so each display name keeps its last occurrence,
    # remembering the original agent identifier alongside it.
    seen = OrderedDict()
    for a in reversed(agents):
        transformed = transform_string(a.get("agent"))
        if transformed not in seen:
            seen[transformed] = a.get("agent")

    # Create the final list with unique display names but original IDs
    unique_agents = [{"id": agent_id, "display": name} for name, agent_id in seen.items()]
    return {"agents": unique_agents}

def cleanup():
    # Stop the FIFO scheduler when the server process exits.
    stopScheduler()


atexit.register(cleanup)


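# A minimal launcher sketch, assuming uvicorn is installed; the host and port
# below are illustrative defaults rather than values taken from this repository.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)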