test_run.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from __future__ import annotations
  2. import dataclasses
  3. import json
  4. import logging
  5. import os
  6. import subprocess
  7. from pathlib import Path
  8. from typing import Any
  9. import pytest
  10. import docker
  11. from run import ActionsArguments, Main, MainHook, OpenPRHook, ScriptArguments
  12. from sweagent.agent.agents import Agent, AgentArguments, AgentHook
  13. from sweagent.agent.models import ModelArguments
  14. from sweagent.environment.swe_env import EnvironmentArguments, SWEEnv
  15. @pytest.mark.slow()
  16. def test_run_cli_help():
  17. args = [
  18. "python",
  19. "run.py",
  20. "--help",
  21. ]
  22. subprocess.run(args, check=True)
  23. @pytest.fixture()
  24. def open_pr_hook_init_for_sop():
  25. hook = OpenPRHook()
  26. hook._token = os.environ.get("GITHUB_TOKEN", "")
  27. hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/1"
  28. hook._open_pr = True
  29. hook._skip_if_commits_reference_issue = True
  30. return hook
  31. @pytest.fixture()
  32. def info_dict():
  33. return {
  34. "submission": "asdf",
  35. "exit_status": "submitted",
  36. }
  37. def test_should_open_pr_fail_submission(open_pr_hook_init_for_sop, info_dict):
  38. hook = open_pr_hook_init_for_sop
  39. info_dict["submission"] = None
  40. assert not hook.should_open_pr(info_dict)
  41. def test_should_open_pr_fail_exit(open_pr_hook_init_for_sop, info_dict):
  42. hook = open_pr_hook_init_for_sop
  43. info_dict["exit_status"] = "fail"
  44. assert not hook.should_open_pr(info_dict)
  45. def test_should_open_pr_fail_invalid_url(open_pr_hook_init_for_sop, info_dict):
  46. hook = open_pr_hook_init_for_sop
  47. hook._data_path = "asdf"
  48. assert not hook.should_open_pr(info_dict)
  49. def test_should_open_pr_fail_closed(open_pr_hook_init_for_sop, info_dict):
  50. hook = open_pr_hook_init_for_sop
  51. hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/16"
  52. assert not hook.should_open_pr(info_dict)
  53. def test_should_open_pr_fail_assigned(open_pr_hook_init_for_sop, info_dict):
  54. hook = open_pr_hook_init_for_sop
  55. hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/17"
  56. assert not hook.should_open_pr(info_dict)
  57. def test_should_open_pr_fail_locked(open_pr_hook_init_for_sop, info_dict):
  58. hook = open_pr_hook_init_for_sop
  59. hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/18"
  60. assert not hook.should_open_pr(info_dict)
  61. def test_should_open_pr_fail_has_pr(open_pr_hook_init_for_sop, info_dict):
  62. hook = open_pr_hook_init_for_sop
  63. hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/19"
  64. assert not hook.should_open_pr(info_dict)
  65. def test_should_open_pr_success_has_pr_override(open_pr_hook_init_for_sop, info_dict):
  66. hook = open_pr_hook_init_for_sop
  67. hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/19"
  68. hook._skip_if_commits_reference_issue = False
  69. assert hook.should_open_pr(info_dict)
  70. class RaisesExceptionHook(MainHook):
  71. def on_instance_start(self, *, index: int, instance: dict[str, Any]):
  72. msg = "test exception"
  73. raise ValueError(msg)
  74. @pytest.fixture()
  75. def test_script_args():
  76. return ScriptArguments(
  77. suffix="",
  78. environment=EnvironmentArguments(
  79. image_name="sweagent/swe-agent:latest",
  80. data_path="https://github.com/klieret/swe-agent-test-repo/issues/1",
  81. split="dev",
  82. verbose=True,
  83. install_environment=True,
  84. ),
  85. skip_existing=False,
  86. agent=AgentArguments(
  87. model=ModelArguments(
  88. model_name="instant_empty_submit",
  89. total_cost_limit=0.0,
  90. per_instance_cost_limit=3.0,
  91. temperature=0.0,
  92. top_p=0.95,
  93. ),
  94. config_file=Path("config/default.yaml"),
  95. ),
  96. actions=ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True),
  97. raise_exceptions=True,
  98. print_config=False,
  99. )
  100. @pytest.mark.slow()
  101. def test_exception_raised(test_script_args: ScriptArguments):
  102. assert test_script_args.raise_exceptions
  103. main = Main(test_script_args)
  104. main.add_hook(RaisesExceptionHook())
  105. with pytest.raises(ValueError, match="test exception"):
  106. main.main()
  107. @pytest.mark.slow()
  108. class CreateFakeLogFile(MainHook):
  109. """Testing the skip functionality"""
  110. def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
  111. self._traj_dir = traj_dir
  112. (traj_dir / "args.yaml").write_text("asdf")
  113. def on_instance_start(self, *, index: int, instance: dict[str, Any]):
  114. instance_id = instance["instance_id"]
  115. dct = {
  116. "info": {"exit_status": "submitted"},
  117. }
  118. (self._traj_dir / f"{instance_id}.traj").write_text(json.dumps(dct))
  119. @pytest.mark.slow()
  120. def test_existing_corrupted_args(test_script_args: ScriptArguments):
  121. main = Main(test_script_args)
  122. main.add_hook(CreateFakeLogFile())
  123. main.main()
  124. @pytest.mark.slow()
  125. def test_main_hook(test_script_args: ScriptArguments):
  126. main = Main(test_script_args)
  127. main.add_hook(MainHook())
  128. main.main()
  129. @pytest.mark.slow()
  130. def test_agent_with_hook(test_script_args: ScriptArguments):
  131. main = Main(test_script_args)
  132. main.agent.add_hook(AgentHook())
  133. main.main()
  134. PERSISTENT_CONTAINER_NAME = "sweagent-test-persistent-container"
  135. @pytest.fixture()
  136. def _cleanup_persistent_container():
  137. yield
  138. client = docker.from_env()
  139. container = client.containers.get(PERSISTENT_CONTAINER_NAME)
  140. container.remove(force=True)
  141. @pytest.mark.slow()
  142. @pytest.mark.usefixtures("_cleanup_persistent_container")
  143. def test_agent_persistent_container(test_script_args: ScriptArguments, capsys):
  144. test_script_args = dataclasses.replace(
  145. test_script_args,
  146. environment=dataclasses.replace(test_script_args.environment, container_name=PERSISTENT_CONTAINER_NAME),
  147. )
  148. assert test_script_args.environment.verbose
  149. main = Main(test_script_args)
  150. assert main.env.logger.isEnabledFor(logging.DEBUG)
  151. main.main()
  152. captured = capsys.readouterr()
  153. print("---")
  154. print(captured.out)
  155. print("---")
  156. print(captured.err)
  157. print("---")
  158. text = captured.out + captured.err
  159. assert "Trying to clone from non-mirror..." in text
  160. assert "Falling back to full cloning method" in text