123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- from __future__ import annotations
- import dataclasses
- import json
- import logging
- import os
- import subprocess
- from pathlib import Path
- from typing import Any
- import pytest
- import docker
- from run import ActionsArguments, Main, MainHook, OpenPRHook, ScriptArguments
- from sweagent.agent.agents import Agent, AgentArguments, AgentHook
- from sweagent.agent.models import ModelArguments
- from sweagent.environment.swe_env import EnvironmentArguments, SWEEnv
@pytest.mark.slow()
def test_run_cli_help():
    """Smoke-test the command line entry point by running ``run.py --help``.

    ``check=True`` makes ``subprocess.run`` raise ``CalledProcessError``
    (and therefore fail the test) when the process exits with a non-zero
    status.
    """
    # Function-scope import: ``sys`` is not among this module's top-level imports.
    import sys

    # Launch the interpreter that is running the test suite instead of
    # whatever ``python`` happens to resolve to on PATH (it may be missing
    # or belong to a different environment).
    command = [sys.executable, "run.py", "--help"]
    subprocess.run(command, check=True)
@pytest.fixture()
def open_pr_hook_init_for_sop():
    """Provide an ``OpenPRHook`` pre-configured for the ``should_open_pr`` tests.

    The hook's private attributes are filled in directly: the GitHub token is
    read from the ``GITHUB_TOKEN`` environment variable (empty string when
    unset), the data path points at issue #1 of the test repository, and both
    ``_open_pr`` and ``_skip_if_commits_reference_issue`` are switched on.
    """
    pr_hook = OpenPRHook()
    github_token = os.getenv("GITHUB_TOKEN", "")
    pr_hook._token = github_token
    pr_hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/1"
    pr_hook._open_pr = True
    pr_hook._skip_if_commits_reference_issue = True
    return pr_hook
@pytest.fixture()
def info_dict():
    """Provide a fresh info dictionary describing a submitted run."""
    return dict(submission="asdf", exit_status="submitted")
def test_should_open_pr_fail_submission(open_pr_hook_init_for_sop, info_dict):
    """No PR should be opened when the submission is ``None``."""
    info_dict["submission"] = None
    should_open = open_pr_hook_init_for_sop.should_open_pr(info_dict)
    assert not should_open
def test_should_open_pr_fail_exit(open_pr_hook_init_for_sop, info_dict):
    """No PR should be opened when the exit status is ``"fail"``."""
    info_dict["exit_status"] = "fail"
    should_open = open_pr_hook_init_for_sop.should_open_pr(info_dict)
    assert not should_open
def test_should_open_pr_fail_invalid_url(open_pr_hook_init_for_sop, info_dict):
    """No PR should be opened when the data path is the non-URL string ``"asdf"``."""
    pr_hook = open_pr_hook_init_for_sop
    pr_hook._data_path = "asdf"
    should_open = pr_hook.should_open_pr(info_dict)
    assert not should_open
def test_should_open_pr_fail_closed(open_pr_hook_init_for_sop, info_dict):
    """No PR should be opened for issue #16 of the test repository.

    Judging by the test name, that issue is expected to be closed.
    """
    pr_hook = open_pr_hook_init_for_sop
    pr_hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/16"
    should_open = pr_hook.should_open_pr(info_dict)
    assert not should_open
def test_should_open_pr_fail_assigned(open_pr_hook_init_for_sop, info_dict):
    """No PR should be opened for issue #17 of the test repository.

    Judging by the test name, that issue is expected to have an assignee.
    """
    pr_hook = open_pr_hook_init_for_sop
    pr_hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/17"
    should_open = pr_hook.should_open_pr(info_dict)
    assert not should_open
def test_should_open_pr_fail_locked(open_pr_hook_init_for_sop, info_dict):
    """No PR should be opened for issue #18 of the test repository.

    Judging by the test name, that issue is expected to be locked.
    """
    pr_hook = open_pr_hook_init_for_sop
    pr_hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/18"
    should_open = pr_hook.should_open_pr(info_dict)
    assert not should_open
def test_should_open_pr_fail_has_pr(open_pr_hook_init_for_sop, info_dict):
    """No PR should be opened for issue #19 of the test repository.

    Judging by the test name, that issue is expected to be referenced by an
    existing PR or commit already; the fixture leaves
    ``_skip_if_commits_reference_issue`` enabled.
    """
    pr_hook = open_pr_hook_init_for_sop
    pr_hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/19"
    should_open = pr_hook.should_open_pr(info_dict)
    assert not should_open
def test_should_open_pr_success_has_pr_override(open_pr_hook_init_for_sop, info_dict):
    """A PR should be opened for issue #19 once the skip flag is turned off.

    Same issue as the ``fail_has_pr`` case, but with
    ``_skip_if_commits_reference_issue`` set to ``False``.
    """
    pr_hook = open_pr_hook_init_for_sop
    pr_hook._skip_if_commits_reference_issue = False
    pr_hook._data_path = "https://github.com/klieret/swe-agent-test-repo/issues/19"
    should_open = pr_hook.should_open_pr(info_dict)
    assert should_open
class RaisesExceptionHook(MainHook):
    """Hook that raises ``ValueError`` as soon as an instance starts."""

    def on_instance_start(self, *, index: int, instance: dict[str, Any]):
        """Raise ``ValueError("test exception")`` unconditionally."""
        raise ValueError("test exception")
@pytest.fixture()
def test_script_args():
    """Provide the ``ScriptArguments`` shared by the slow end-to-end tests.

    The configuration targets issue #1 of the test repository, uses the
    ``instant_empty_submit`` model, re-raises exceptions, and does not open
    pull requests.
    """
    environment_args = EnvironmentArguments(
        image_name="sweagent/swe-agent:latest",
        data_path="https://github.com/klieret/swe-agent-test-repo/issues/1",
        split="dev",
        verbose=True,
        install_environment=True,
    )
    model_args = ModelArguments(
        model_name="instant_empty_submit",
        total_cost_limit=0.0,
        per_instance_cost_limit=3.0,
        temperature=0.0,
        top_p=0.95,
    )
    agent_args = AgentArguments(
        model=model_args,
        config_file=Path("config/default.yaml"),
    )
    action_args = ActionsArguments(open_pr=False, skip_if_commits_reference_issue=True)
    return ScriptArguments(
        suffix="",
        environment=environment_args,
        skip_existing=False,
        agent=agent_args,
        actions=action_args,
        raise_exceptions=True,
        print_config=False,
    )
@pytest.mark.slow()
def test_exception_raised(test_script_args: ScriptArguments):
    """An exception raised inside a hook must propagate out of ``Main.main``.

    The precondition ``raise_exceptions`` is asserted first, then the run is
    expected to fail with the ``ValueError`` thrown by ``RaisesExceptionHook``.
    """
    assert test_script_args.raise_exceptions
    runner = Main(test_script_args)
    failing_hook = RaisesExceptionHook()
    runner.add_hook(failing_hook)
    with pytest.raises(ValueError, match="test exception"):
        runner.main()
# Helper hook, not a test: pytest never collects it, so it carries no marker.
class CreateFakeLogFile(MainHook):
    """Testing the skip functionality.

    Plants files in the trajectory directory: an ``args.yaml`` holding the
    placeholder text ``asdf`` at init time, and a ``<instance_id>.traj`` file
    reporting ``exit_status: submitted`` whenever an instance starts.
    """

    def on_init(self, *, args: ScriptArguments, agent: Agent, env: SWEEnv, traj_dir: Path):
        """Remember ``traj_dir`` and write a placeholder ``args.yaml`` into it."""
        self._traj_dir = traj_dir
        (traj_dir / "args.yaml").write_text("asdf")

    def on_instance_start(self, *, index: int, instance: dict[str, Any]):
        """Write a minimal trajectory file for ``instance``, marked as submitted."""
        instance_id = instance["instance_id"]
        fake_trajectory = {"info": {"exit_status": "submitted"}}
        traj_file = self._traj_dir / f"{instance_id}.traj"
        traj_file.write_text(json.dumps(fake_trajectory))
@pytest.mark.slow()
def test_existing_corrupted_args(test_script_args: ScriptArguments):
    """Run the main loop with ``CreateFakeLogFile`` installed.

    The test passes when ``Main.main`` returns without raising.
    """
    runner = Main(test_script_args)
    fake_log_hook = CreateFakeLogFile()
    runner.add_hook(fake_log_hook)
    runner.main()
@pytest.mark.slow()
def test_main_hook(test_script_args: ScriptArguments):
    """Run the main loop with a plain ``MainHook`` instance attached.

    The test passes when ``Main.main`` returns without raising.
    """
    runner = Main(test_script_args)
    base_hook = MainHook()
    runner.add_hook(base_hook)
    runner.main()
@pytest.mark.slow()
def test_agent_with_hook(test_script_args: ScriptArguments):
    """Run the main loop with a plain ``AgentHook`` attached to the agent.

    The test passes when ``Main.main`` returns without raising.
    """
    runner = Main(test_script_args)
    agent_hook = AgentHook()
    runner.agent.add_hook(agent_hook)
    runner.main()
# Name of the Docker container used by the persistent-container test and
# removed again by the cleanup fixture below.
PERSISTENT_CONTAINER_NAME = "sweagent-test-persistent-container"


@pytest.fixture()
def _cleanup_persistent_container():
    """Force-remove the persistent test container once the test has finished.

    Everything after ``yield`` runs as teardown. A missing container is
    tolerated so that a test which failed before creating it is reported
    with its own error instead of an additional teardown error.
    """
    yield
    client = docker.from_env()
    try:
        container = client.containers.get(PERSISTENT_CONTAINER_NAME)
        container.remove(force=True)
    except docker.errors.NotFound:
        # Nothing to clean up: the container was never created (or is
        # already gone), most likely because the test failed early.
        pass
    finally:
        # Release the HTTP session held by the Docker client.
        client.close()
@pytest.mark.slow()
@pytest.mark.usefixtures("_cleanup_persistent_container")
def test_agent_persistent_container(test_script_args: ScriptArguments, capsys):
    """Run the main loop against a named (persistent) container.

    The environment arguments are copied with ``container_name`` set to
    ``PERSISTENT_CONTAINER_NAME``. After the run, the captured stdout and
    stderr are echoed for debugging and must contain the two expected
    cloning log messages.
    """
    persistent_env = dataclasses.replace(
        test_script_args.environment,
        container_name=PERSISTENT_CONTAINER_NAME,
    )
    args = dataclasses.replace(test_script_args, environment=persistent_env)
    assert args.environment.verbose

    runner = Main(args)
    assert runner.env.logger.isEnabledFor(logging.DEBUG)
    runner.main()

    captured = capsys.readouterr()
    # Echo what was captured, framed by separators, to ease debugging on failure.
    print("---", captured.out, "---", captured.err, "---", sep="\n")

    combined_output = captured.out + captured.err
    for expected_message in (
        "Trying to clone from non-mirror...",
        "Falling back to full cloning method",
    ):
        assert expected_message in combined_output
- assert "Falling back to full cloning method" in text
|