test_replay_ctf.py 3.2 KB
  1. from __future__ import annotations
  2. import copy
  3. from typing import Any
  4. import pytest
  5. from run_replay import get_args, main
  6. LOG_NOT_CONTAINS_DEFAULT = [
  7. "Traceback",
  8. "Exception",
  9. "socket.gaierror",
  10. ]
  11. class ReplayRunValidator:
  12. def __init__(
  13. self,
  14. *,
  15. log_contains: list[str] | None = None,
  16. log_not_contains: list[str] | None = None,
  17. expected_traj: str | None = None,
  18. ):
  19. if log_contains is None:
  20. log_contains = []
  21. if log_not_contains is None:
  22. log_not_contains = copy.copy(LOG_NOT_CONTAINS_DEFAULT)
  23. self._log_contains = log_contains
  24. self._log_not_contains = log_not_contains + LOG_NOT_CONTAINS_DEFAULT
  25. self._expected_traj = expected_traj
  26. def _sanitize_observation(self, observation: str) -> str:
  27. # exclude everything that looks like a path
  28. return "\n".join(line for line in observation.splitlines() if "/" not in line).strip()
  29. def _sanitize_traj(self, traj: dict[str, Any]) -> dict[str, Any]:
  30. traj = copy.deepcopy(traj)
  31. # can restore later
  32. observations = [self._sanitize_observation(t["observation"]) for t in traj["trajectory"]]
  33. return {"trajectory": observations}
  34. def __call__(self, stdout: str, traj: str | None = None) -> None:
  35. for log in self._log_contains:
  36. assert log in stdout, log
  37. for log in self._log_not_contains:
  38. assert log not in stdout, log
  39. if self._expected_traj is not None:
  40. assert traj == self._expected_traj
  41. _REPLAY_TESTS = {
  42. "pwn/warmup.traj": ReplayRunValidator(
  43. log_contains=[
  44. "File updated.",
  45. "Opening connection to pwn.chal.csaw.io on port 8000: Done",
  46. "Receiving all data",
  47. ],
  48. ),
  49. "forensics/flash.traj": ReplayRunValidator(
  50. log_contains=["the black flag waved night and day from the"],
  51. ),
  52. # "web/i_got_id_demo.traj": r"4365",
  53. "misc/networking_1.traj": ReplayRunValidator(
  54. log_contains=["Password: "],
  55. ),
  56. }
  57. @pytest.mark.slow
  58. @pytest.mark.ctf
  59. @pytest.mark.parametrize(
  60. "traj_rel_path",
  61. ["pwn/warmup.traj", "forensics/flash.traj", "web/i_got_id_demo.traj", "rev/rock.traj", "misc/networking_1.traj"],
  62. )
  63. def test_ctf_traj_replay(test_ctf_trajectories_path, traj_rel_path, ctf_data_path, capsys):
  64. # if sys.platform == "darwin" and traj_rel_path in ["pwn/warmup.traj", "rev/rock.traj"]:
  65. # pytest.skip("Skipping test on macOS")
  66. traj_path = test_ctf_trajectories_path / traj_rel_path
  67. challenge_dir = ctf_data_path / traj_rel_path.removesuffix(".traj")
  68. assert challenge_dir.is_dir()
  69. data_path = challenge_dir / "challenge.json"
  70. assert data_path.is_file()
  71. assert traj_path.is_file()
  72. args = [
  73. "--traj_path",
  74. str(traj_path),
  75. "--data_path",
  76. str(data_path),
  77. "--repo_path",
  78. str(challenge_dir),
  79. "--config_file",
  80. "config/default_ctf.yaml",
  81. "--raise_exceptions",
  82. "--noprint_config",
  83. "--ctf",
  84. ]
  85. args, remaining_args = get_args(args)
  86. main(**vars(args), forward_args=remaining_args)
  87. captured = capsys.readouterr()
  88. if traj_rel_path in _REPLAY_TESTS:
  89. _REPLAY_TESTS[traj_rel_path](stdout=captured.out)