server.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. try:
  2. import flask # noqa
  3. except ImportError as e:
  4. msg = (
  5. "Flask not found. You probably haven't installed the dependencies for SWE-agent. "
  6. "Please go to the root of the repository and run `pip install -e .`"
  7. )
  8. raise RuntimeError(msg) from e
  9. from contextlib import redirect_stderr, redirect_stdout
  10. import copy
  11. import json
  12. import os
  13. from pathlib import Path
  14. import time
  15. import traceback
  16. from typing import Any, Dict
  17. from flask import Flask, render_template, request, make_response
  18. import sys
  19. import yaml
  20. from sweagent import CONFIG_DIR, PACKAGE_DIR
  21. from sweagent.agent.agents import AgentArguments
  22. from sweagent.agent.models import ModelArguments
  23. from sweagent.api.utils import ThreadWithExc, AttrDict
  24. from sweagent.environment.swe_env import EnvironmentArguments
  25. from sweagent.api.hooks import EnvUpdateHook, WebUpdate, MainUpdateHook, AgentUpdateHook
  26. import sweagent.environment.utils as env_utils
  27. from flask_socketio import SocketIO
  28. from flask_cors import CORS
  29. from flask import session
  30. from uuid import uuid4
  31. import tempfile
  32. import atexit
# HACK (bad practice): extend sys.path with the parent of PACKAGE_DIR so that `run` can be imported below
  34. sys.path.append(str(PACKAGE_DIR.parent))
  35. from run import ActionsArguments, ScriptArguments, Main
  36. app = Flask(__name__)
  37. CORS(app)
  38. socketio = SocketIO(app, cors_allowed_origins="*")
  39. # Setting these variables outside of `if __name__ == "__main__"` because when run Flask server with
  40. # `flask run` it will skip the if block. Therefore, the app will return an error for missing `secret_key`
  41. # Setting it here will allow both `flask run` and `python server.py` to work
  42. app.secret_key = "super secret key"
  43. app.config["SESSION_TYPE"] = "memcache"
  44. THREADS: Dict[str, "MainThread"] = {}
  45. os.environ["SWE_AGENT_EXPERIMENTAL_COMMUNICATE"] = "1"
  46. env_utils.START_UP_DELAY = 1
  47. def ensure_session_id_set():
  48. """Ensures a session ID is set for this user"""
  49. session_id = session.get("session_id", None)
  50. if not session_id:
  51. session_id = uuid4().hex
  52. session["session_id"] = session_id
  53. return session_id
  54. class MainThread(ThreadWithExc):
  55. def __init__(self, settings: ScriptArguments, wu: WebUpdate):
  56. super().__init__()
  57. self._wu = wu
  58. self._settings = settings
  59. def run(self) -> None:
  60. # fixme: This actually redirects all output from all threads to the socketio, which is not what we want
  61. with redirect_stdout(self._wu.log_stream):
  62. with redirect_stderr(self._wu.log_stream):
  63. try:
  64. main = Main(self._settings)
  65. main.add_hook(MainUpdateHook(self._wu))
  66. main.agent.add_hook(AgentUpdateHook(self._wu))
  67. main.env.add_hook(EnvUpdateHook(self._wu))
  68. main.main()
  69. except Exception as e:
  70. short_msg = str(e)
  71. max_len = 350
  72. if len(short_msg) > max_len:
  73. short_msg = f"{short_msg[:max_len]}... (see log for details)"
  74. traceback_str = traceback.format_exc()
  75. self._wu.up_log(traceback_str)
  76. self._wu.up_agent(f"Error: {short_msg}")
  77. self._wu.up_banner("Critical error: " + short_msg)
  78. self._wu.finish_run()
  79. raise
  80. def stop(self):
  81. while self.is_alive():
  82. self.raise_exc(SystemExit)
  83. time.sleep(0.1)
  84. self._wu.finish_run()
  85. self._wu.up_agent("Run stopped by user")
  86. @app.route("/")
  87. def index():
  88. return render_template("index.html")
  89. @socketio.on("connect")
  90. def handle_connect():
  91. print("Client connected")
  92. def write_env_yaml(data) -> str:
  93. data: Any = AttrDict(copy.deepcopy(dict(data)))
  94. if not data.install_command_active:
  95. data.install = ""
  96. del data.install_command_active
  97. data.pip_packages = [p.strip() for p in data.pip_packages.split("\n") if p.strip()]
  98. path = Path(tempfile.NamedTemporaryFile(delete=False, suffix=".yml").name)
  99. # Make sure that the file is deleted when the program exits
  100. atexit.register(path.unlink)
  101. path.write_text(yaml.dump(dict(data)))
  102. return str(path)
  103. @app.route("/run", methods=["GET", "OPTIONS"])
  104. def run():
  105. session_id = ensure_session_id_set()
  106. if request.method == "OPTIONS": # CORS preflight
  107. return _build_cors_preflight_response()
  108. # While we're running as a local UI, let's make sure that there's at most
  109. # one run at a time
  110. global THREADS
  111. for thread in THREADS.values():
  112. if thread.is_alive():
  113. thread.stop()
  114. wu = WebUpdate(socketio)
  115. wu.up_agent("Starting the run")
  116. # Use Any type to silence annoying false positives from mypy
  117. run: Any = AttrDict.from_nested_dicts(json.loads(request.args["runConfig"]))
  118. print(run)
  119. print(run.environment)
  120. print(run.environment.base_commit)
  121. model_name: str = run.agent.model.model_name
  122. environment_setup = ""
  123. environment_input_type = run.environment.environment_setup.input_type
  124. if environment_input_type == "manual":
  125. environment_setup = str(write_env_yaml(run.environment.environment_setup.manual))
  126. elif environment_input_type == "script_path":
  127. environment_setup = run.environment.environment_setup.script_path["script_path"]
  128. else:
  129. raise ValueError(f"Unknown input type: {environment_input_type}")
  130. if not environment_setup.strip():
  131. environment_setup = None
  132. test_run: bool = run.extra.test_run
  133. if test_run:
  134. model_name = "instant_empty_submit"
  135. defaults = ScriptArguments(
  136. suffix="",
  137. environment=EnvironmentArguments(
  138. image_name="sweagent/swe-agent:latest",
  139. data_path=run.environment.data_path,
  140. base_commit=run.environment.base_commit,
  141. split="dev",
  142. verbose=True,
  143. install_environment=True,
  144. repo_path=run.environment.repo_path,
  145. environment_setup=environment_setup,
  146. ),
  147. skip_existing=False,
  148. agent=AgentArguments(
  149. model=ModelArguments(
  150. model_name=model_name,
  151. total_cost_limit=0.0,
  152. per_instance_cost_limit=3.0,
  153. temperature=0.0,
  154. top_p=0.95,
  155. ),
  156. config_file=CONFIG_DIR / "default_from_url.yaml",
  157. ),
  158. actions=ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True),
  159. raise_exceptions=True,
  160. )
  161. thread = MainThread(defaults, wu)
  162. THREADS[session_id] = thread
  163. thread.start()
  164. return "Commands are being executed", 202
  165. @app.route("/stop")
  166. def stop():
  167. session_id = ensure_session_id_set()
  168. global THREADS
  169. print(f"Stopping session {session_id}")
  170. print(THREADS)
  171. thread = THREADS.get(session_id)
  172. if thread and thread.is_alive():
  173. print(f"Thread {thread} is alive")
  174. thread.stop()
  175. else:
  176. print(f"Thread {thread} is not alive")
  177. return "Stopping computation", 202
  178. def _build_cors_preflight_response():
  179. response = make_response()
  180. response.headers.add("Access-Control-Allow-Origin", "*")
  181. response.headers.add("Access-Control-Allow-Headers", "*")
  182. response.headers.add("Access-Control-Allow-Methods", "*")
  183. return response
  184. if __name__ == "__main__":
  185. app.debug = True
  186. socketio.run(app, port=8000, debug=True)