test_run.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. from __future__ import annotations
  2. import dataclasses
  3. import json
  4. import logging
  5. import os
  6. import subprocess
  7. from pathlib import Path
  8. from typing import Any
  9. import pytest
  10. import docker
  11. from run import ActionsArguments, Main, MainHook, OpenPRHook, ScriptArguments
  12. from sweagent.agent.agents import Agent, AgentArguments, AgentHook
  13. from sweagent.agent.models import ModelArguments
  14. from sweagent.environment.swe_env import EnvironmentArguments, SWEEnv
  15. @pytest.mark.slow
  16. def test_run_cli_help():
  17. args = [
  18. "python",
  19. "run.py",
  20. "--help",
  21. ]
  22. subprocess.run(args, check=True)
  23. @pytest.fixture
  24. def open_pr_hook_init_for_sop():
  25. hook = OpenPRHook()
  26. hook._token = os.environ.get("GITHUB_TOKEN", "")
  27. hook._data_path = "https://github.com/swe-agent/test-repo/issues/1"
  28. hook._open_pr = True
  29. hook._skip_if_commits_reference_issue = True
  30. return hook
  31. @pytest.fixture
  32. def info_dict():
  33. return {
  34. "submission": "asdf",
  35. "exit_status": "submitted",
  36. }
  37. def test_should_open_pr_fail_submission(open_pr_hook_init_for_sop, info_dict):
  38. hook = open_pr_hook_init_for_sop
  39. info_dict["submission"] = None
  40. assert not hook.should_open_pr(info_dict)
  41. def test_should_open_pr_fail_exit(open_pr_hook_init_for_sop, info_dict):
  42. hook = open_pr_hook_init_for_sop
  43. info_dict["exit_status"] = "fail"
  44. assert not hook.should_open_pr(info_dict)
  45. def test_should_open_pr_fail_invalid_url(open_pr_hook_init_for_sop, info_dict):
  46. hook = open_pr_hook_init_for_sop
  47. hook._data_path = "asdf"
  48. assert not hook.should_open_pr(info_dict)
  49. def test_should_open_pr_fail_closed(open_pr_hook_init_for_sop, info_dict):
  50. hook = open_pr_hook_init_for_sop
  51. hook._data_path = "https://github.com/swe-agent/test-repo/issues/16"
  52. assert not hook.should_open_pr(info_dict)
  53. def test_should_open_pr_fail_assigned(open_pr_hook_init_for_sop, info_dict):
  54. hook = open_pr_hook_init_for_sop
  55. hook._data_path = "https://github.com/swe-agent/test-repo/issues/17"
  56. assert not hook.should_open_pr(info_dict)
  57. def test_should_open_pr_fail_locked(open_pr_hook_init_for_sop, info_dict):
  58. hook = open_pr_hook_init_for_sop
  59. hook._data_path = "https://github.com/swe-agent/test-repo/issues/18"
  60. assert not hook.should_open_pr(info_dict)
  61. def test_should_open_pr_fail_has_pr(open_pr_hook_init_for_sop, info_dict):
  62. hook = open_pr_hook_init_for_sop
  63. hook._data_path = "https://github.com/swe-agent/test-repo/issues/19"
  64. assert not hook.should_open_pr(info_dict)
  65. def test_should_open_pr_success_has_pr_override(open_pr_hook_init_for_sop, info_dict):
  66. hook = open_pr_hook_init_for_sop
  67. hook._data_path = "https://github.com/swe-agent/test-repo/issues/19"
  68. hook._skip_if_commits_reference_issue = False
  69. assert hook.should_open_pr(info_dict)
  70. class RaisesExceptionHook(MainHook):
  71. def on_instance_start(self, *, index: int, instance: dict[str, Any]):
  72. msg = "test exception"
  73. raise ValueError(msg)
  74. @pytest.fixture
  75. def test_script_args():
  76. return ScriptArguments(
  77. suffix="",
  78. environment=EnvironmentArguments(
  79. image_name="sweagent/swe-agent:latest",
  80. data_path="https://github.com/swe-agent/test-repo/issues/1",
  81. split="dev",
  82. verbose=True,
  83. install_environment=True,
  84. ),
  85. skip_existing=False,
  86. agent=AgentArguments(
  87. model=ModelArguments(
  88. model_name="instant_empty_submit",
  89. ),
  90. config_file=Path("config/default.yaml"),
  91. ),
  92. actions=ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True),
  93. raise_exceptions=True,
  94. print_config=False,
  95. )
  96. @pytest.mark.slow
  97. def test_exception_raised(test_script_args: ScriptArguments):
  98. assert test_script_args.raise_exceptions
  99. main = Main(test_script_args)
  100. main.add_hook(RaisesExceptionHook())
  101. with pytest.raises(ValueError, match="test exception"):
  102. main.main()
  103. @pytest.mark.slow
  104. class CreateFakeLogFile(MainHook):
  105. """Testing the skip functionality"""
  106. def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
  107. self._traj_dir = traj_dir
  108. (traj_dir / "args.yaml").write_text("asdf")
  109. def on_instance_start(self, *, index: int, instance: dict[str, Any]):
  110. instance_id = instance["instance_id"]
  111. dct = {
  112. "info": {"exit_status": "submitted"},
  113. }
  114. (self._traj_dir / f"{instance_id}.traj").write_text(json.dumps(dct))
  115. @pytest.mark.slow
  116. def test_existing_corrupted_args(test_script_args: ScriptArguments):
  117. main = Main(test_script_args)
  118. main.add_hook(CreateFakeLogFile())
  119. main.main()
  120. @pytest.mark.slow
  121. def test_main_hook(test_script_args: ScriptArguments):
  122. main = Main(test_script_args)
  123. main.add_hook(MainHook())
  124. main.main()
  125. @pytest.mark.slow
  126. def test_agent_with_hook(test_script_args: ScriptArguments):
  127. main = Main(test_script_args)
  128. main.agent.add_hook(AgentHook())
  129. main.main()
# Container name passed as ``container_name`` by the persistent-container test
# and looked up again by its cleanup fixture to remove that container.
PERSISTENT_CONTAINER_NAME = "sweagent-test-persistent-container"
  131. @pytest.fixture
  132. def _cleanup_persistent_container():
  133. yield
  134. client = docker.from_env()
  135. container = client.containers.get(PERSISTENT_CONTAINER_NAME)
  136. container.remove(force=True)
  137. @pytest.mark.slow
  138. @pytest.mark.usefixtures("_cleanup_persistent_container")
  139. def test_agent_persistent_container(test_script_args: ScriptArguments, capsys):
  140. test_script_args = dataclasses.replace(
  141. test_script_args,
  142. environment=dataclasses.replace(test_script_args.environment, container_name=PERSISTENT_CONTAINER_NAME),
  143. )
  144. assert test_script_args.environment.verbose
  145. main = Main(test_script_args)
  146. assert main.env.logger.isEnabledFor(logging.DEBUG)
  147. main.main()
  148. captured = capsys.readouterr()
  149. print("---")
  150. print(captured.out)
  151. print("---")
  152. print(captured.err)
  153. print("---")
  154. text = captured.out + captured.err
  155. assert "Trying to clone from non-mirror..." in text
  156. assert "Falling back to full cloning method" in text
  157. def test_dummy_interactive_session(test_script_args: ScriptArguments):
  158. test_script_args = dataclasses.replace(
  159. test_script_args,
  160. agent=AgentArguments(
  161. model=ModelArguments(
  162. model_name="instant_empty_submit",
  163. ),
  164. config_file=Path("tests", "test_data", "config_files", "dummy_interactive.yaml"),
  165. ),
  166. )
  167. print(test_script_args.agent.config.command_docs) # type: ignore
  168. main = Main(test_script_args)
  169. env = main.env
  170. env.reset()
  171. main.agent.set_environment_vars(env, {})
  172. action_obs = [
  173. ("doesntexit", "command not found"),
  174. ("dummy_stop", "is not running"),
  175. ("dummy_send", "is not running"),
  176. ("dummy_start", "Started interactive dummy command"),
  177. ("dummy_start", "Interactive session already open"),
  178. ("dummy_send asdf", "asdf"),
  179. ("dummy_stop", "stopped successfully"),
  180. ("dummy_stop", "is not running"),
  181. ]
  182. for action, expected_observation in action_obs:
  183. observation, *_ = env.step(action)
  184. assert observation is not None
  185. assert expected_observation in observation, observation
  186. env.close()