# test_replay.py (5.3 KB)
  1. from __future__ import annotations
  2. import os
  3. import subprocess
  4. from pathlib import Path
  5. from unittest import mock
  6. import pytest
  7. from run_replay import get_args, main
  8. from sweagent import CONFIG_DIR
  9. @pytest.fixture
  10. def swe_agent_test_repo_clone(tmp_path):
  11. local_repo_path = tmp_path / "test-repo"
  12. clone_cmd = ["git", "clone", "https://github.com/swe-agent/test-repo", local_repo_path]
  13. subprocess.run(clone_cmd, check=True)
  14. return local_repo_path
  15. @pytest.fixture
  16. def swe_agent_test_repo_traj(test_trajectories_path) -> Path:
  17. p = (
  18. test_trajectories_path
  19. / "gpt4__swe-agent-test-repo__default_from_url__t-0.00__p-0.95__c-3.00__install-1"
  20. / "6e44b9__sweagenttestrepo-1c2844.traj"
  21. )
  22. assert p.is_file()
  23. return p
  24. @pytest.fixture
  25. def swe_agent_test_repo_local_problem_stmt(swe_agent_test_repo_clone) -> Path:
  26. problem_stmt = swe_agent_test_repo_clone / "problem_statements" / "1.md"
  27. assert problem_stmt.is_file()
  28. return problem_stmt
  29. @pytest.mark.slow
  30. @pytest.mark.parametrize("problem_statement_source", ["github", "local"])
  31. def test_model_replay_github_repo(
  32. tmpdir,
  33. swe_agent_test_repo_traj,
  34. problem_statement_source,
  35. swe_agent_test_repo_local_problem_stmt,
  36. ):
  37. if problem_statement_source == "github":
  38. data_path = "https://github.com/swe-agent/test-repo/issues/1"
  39. elif problem_statement_source == "local":
  40. data_path = str(swe_agent_test_repo_local_problem_stmt)
  41. args = [
  42. "--traj_path",
  43. str(swe_agent_test_repo_traj.resolve()),
  44. "--data_path",
  45. data_path,
  46. "--config_file",
  47. str(CONFIG_DIR / "default_from_url.yaml"),
  48. "--raise_exceptions",
  49. ]
  50. if problem_statement_source == "local":
  51. args.extend(["--repo_path", "https://github.com/swe-agent/test-repo/"])
  52. args, remaining_args = get_args(args)
  53. with tmpdir.as_cwd():
  54. # Test that we can run run.py also independently from repo dir
  55. main(**vars(args), forward_args=remaining_args)
  56. @pytest.mark.slow
  57. def test_model_replay_from_json(test_trajectories_path, test_data_sources_path):
  58. traj_path = (
  59. test_trajectories_path
  60. / "gpt4__swe-bench-dev-easy_first_only__default__t-0.00__p-0.95__c-3.00__install-1"
  61. / "pydicom__pydicom-1458.traj"
  62. )
  63. assert traj_path.is_file()
  64. data_path = test_data_sources_path / "swe-bench-dev-easy_first_only.json"
  65. assert data_path.is_file()
  66. args = [
  67. "--traj_path",
  68. str(traj_path),
  69. "--data_path",
  70. str(data_path),
  71. "--config_file",
  72. "config/default.yaml",
  73. "--raise_exceptions",
  74. ]
  75. args, remaining_args = get_args(args)
  76. main(**vars(args), forward_args=remaining_args)
  77. def test_run_cli_help():
  78. args = [
  79. "python",
  80. "run_replay.py",
  81. "--help",
  82. ]
  83. subprocess.run(args, check=True)
  84. @pytest.mark.slow
  85. @pytest.mark.parametrize("problem_statement_source", ["github", "local"])
  86. def test_model_replay_local_repo(swe_agent_test_repo_clone, swe_agent_test_repo_traj, problem_statement_source):
  87. local_repo_path = swe_agent_test_repo_clone
  88. if problem_statement_source == "github":
  89. problem_statement_path = "https://github.com/swe-agent/test-repo/issues/1"
  90. elif problem_statement_source == "local":
  91. problem_statement_path = local_repo_path / "problem_statements" / "1.md"
  92. assert problem_statement_path.is_file()
  93. else:
  94. raise ValueError(problem_statement_source)
  95. run_cmd = [
  96. "--traj_path",
  97. str(swe_agent_test_repo_traj),
  98. "--repo_path",
  99. str(local_repo_path),
  100. "--config_file",
  101. "config/default_from_url.yaml",
  102. "--data_path",
  103. str(problem_statement_path),
  104. "--apply_patch",
  105. "--raise_exceptions",
  106. ]
  107. print(run_cmd)
  108. args, remaining_args = get_args(run_cmd)
  109. main(**vars(args), forward_args=remaining_args)
  110. solution = (swe_agent_test_repo_traj.parent / "solution_missing_colon.py").read_text().strip()
  111. solution_retrieved = (local_repo_path / "tests" / "missing_colon.py").read_text().strip()
  112. assert solution == solution_retrieved
  113. def test_exception_replay_local_dirty(swe_agent_test_repo_clone, swe_agent_test_repo_traj):
  114. """Test that swe-agent refuses to work if the local repo is dirty"""
  115. problem_statement_path = swe_agent_test_repo_clone / "problem_statements" / "1.md"
  116. test_file = swe_agent_test_repo_clone / "tests" / "missing_colon.py"
  117. assert test_file.is_file()
  118. test_file.write_text(test_file.read_text().replace("division", "division_function"))
  119. run_cmd = [
  120. "--traj_path",
  121. str(swe_agent_test_repo_traj),
  122. "--repo_path",
  123. str(swe_agent_test_repo_clone),
  124. "--config_file",
  125. "config/default_from_url.yaml",
  126. "--data_path",
  127. str(problem_statement_path),
  128. "--apply_patch",
  129. "--raise_exceptions",
  130. ]
  131. args, remaining_args = get_args(run_cmd)
  132. # In the code, we exclude testing from this check because of tests hosted in the repo
  133. # so here we pretend we're not in a test
  134. modified_environ = {k: v for k, v in os.environ.items() if k != "PYTEST_CURRENT_TEST"}
  135. with mock.patch.dict(os.environ, modified_environ, clear=True):
  136. with pytest.raises(ValueError, match=".*dirty.*"):
  137. main(**vars(args), forward_args=remaining_args)