static.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from __future__ import annotations
  2. import json
  3. import logging
  4. from argparse import ArgumentParser
  5. from pathlib import Path
  6. import yaml
  7. from tqdm.auto import tqdm
  8. try:
  9. from .server import load_content
  10. except ImportError:
  11. from server import load_content
  12. logger = logging.getLogger(__name__)
  13. logging.getLogger("simple_parsing").setLevel(logging.INFO)
  14. TEMPLATE = """
  15. <html>
  16. <head>
  17. <title>Trajectory Viewer</title>
  18. <style>
  19. {style_sheet}
  20. </style>
  21. </head>
  22. <body>
  23. <div class="container">
  24. {file_path_tree}
  25. <h2>Conversation History</h2>
  26. <pre id="fileContent">{file_content}</pre>
  27. </div>
  28. </body>
  29. </html>
  30. """
  31. try:
  32. with open(Path(__file__).parent / "style.css") as infile:
  33. STYLE_SHEET = infile.read()
  34. except Exception as e:
  35. style_file = Path(__file__).parent / "style.css"
  36. logger.error(f"Failed to load style sheet from {style_file}: {e}")
  37. raise e
  38. def _load_file(file_name, gold_patches, test_patches):
  39. try:
  40. role_map = {
  41. "user": "Computer",
  42. "assistant": "SWE-Agent",
  43. "subroutine": "SWE-Agent subroutine",
  44. "default": "Default",
  45. "system": "System",
  46. "demo": "Demonstration",
  47. }
  48. content = load_content(file_name, gold_patches, test_patches)
  49. if "history" in content and isinstance(content["history"], list):
  50. history_content = ""
  51. for index, item in enumerate(content["history"]):
  52. item_content = item.get("content", "").replace("<", "&lt;").replace(">", "&gt;")
  53. if item.get("agent") and item["agent"] != "primary":
  54. role_class = "subroutine"
  55. else:
  56. role_class = item.get("role", "default").lower().replace(" ", "-")
  57. element_id = f"historyItem{index}"
  58. role_name = role_map.get(item.get("role", ""), item.get("role", ""))
  59. history_content += (
  60. f"""<div class="history-item {role_class}" id="{element_id}">"""
  61. f"""<div class="role-bar {role_class}"><strong><span>{role_name}</span></strong></div>"""
  62. f"""<div class="content-container">"""
  63. f"""<pre>{item_content}</pre>"""
  64. f"""</div>"""
  65. f"""<div class="shadow"></div>"""
  66. f"""</div>"""
  67. )
  68. return history_content
  69. else:
  70. return "No history content found."
  71. except Exception as e:
  72. return f"Error loading content. {e}"
  73. def _make_file_path_tree(file_path):
  74. path_parts = file_path.split("/")
  75. relevant_parts = path_parts[-3:]
  76. html_string = '<div class="filepath">\n'
  77. for part in relevant_parts:
  78. html_string += f'<div class="part">{part}</div>\n'
  79. html_string += "</div>"
  80. return html_string
  81. def save_static_viewer(file_path):
  82. if not isinstance(file_path, Path):
  83. file_path = Path(file_path)
  84. data = []
  85. if "args.yaml" in list(map(lambda x: x.name, file_path.parent.iterdir())):
  86. args = yaml.safe_load(Path(file_path.parent / "args.yaml").read_text())
  87. if "environment" in args and "data_path" in args["environment"]:
  88. data_path = Path(__file__).parent.parent / args["environment"]["data_path"]
  89. if data_path.exists():
  90. with open(data_path) as f:
  91. data = json.load(f)
  92. if not isinstance(data, list) or not data or "patch" not in data[0] or "test_patch" not in data[0]:
  93. data = []
  94. gold_patches = {x["instance_id"]: x["patch"] for x in data}
  95. test_patches = {x["instance_id"]: x["test_patch"] for x in data}
  96. content = _load_file(file_path, gold_patches, test_patches)
  97. file_path_tree = _make_file_path_tree(file_path.absolute().as_posix())
  98. icons_path = Path(__file__).parent / "icons"
  99. relative_icons_path = find_relative_path(file_path, icons_path)
  100. style_sheet = STYLE_SHEET.replace("url('icons/", f"url('{relative_icons_path.as_posix()}/").replace(
  101. 'url("icons/',
  102. f'url("{relative_icons_path.as_posix()}/',
  103. )
  104. data = TEMPLATE.format(file_content=content, style_sheet=style_sheet, file_path_tree=file_path_tree)
  105. output_file = file_path.with_suffix(".html")
  106. with open(output_file, "w+") as outfile:
  107. print(data, file=outfile)
  108. logger.info(f"Saved static viewer to {output_file}")
  109. def find_relative_path(from_path, to_path):
  110. # Convert paths to absolute for uniformity
  111. from_path = from_path.resolve()
  112. to_path = to_path.resolve()
  113. if from_path.is_file():
  114. from_path = from_path.parent
  115. if to_path.is_file():
  116. to_path = to_path.parent
  117. if not from_path.is_dir() or not to_path.is_dir():
  118. msg = f"Both from_path and to_path must be directories, but got {from_path} and {to_path}"
  119. raise ValueError(msg)
  120. # Identify the common ancestor and the parts of each path beyond it
  121. common_parts = 0
  122. for from_part, to_part in zip(from_path.parts, to_path.parts):
  123. if from_part != to_part:
  124. break
  125. common_parts += 1
  126. # Calculate the '../' needed to get back from from_path to the common ancestor
  127. back_to_ancestor = [".."] * (len(from_path.parts) - common_parts)
  128. # Direct path from common ancestor to to_path
  129. to_target = to_path.parts[common_parts:]
  130. # Combine to get the relative path
  131. return Path(*back_to_ancestor, *to_target)
  132. def save_all_trajectories(directory):
  133. if not isinstance(directory, Path):
  134. directory = Path(directory)
  135. all_files = list(directory.glob("**/*.traj"))
  136. logger.info(f"Found {len(all_files)} trajectory files in {directory}")
  137. for file_path in tqdm(all_files, desc="Saving static viewers"):
  138. save_static_viewer(file_path)
  139. logger.info(f"Saved static viewers for all trajectories in {args.directory}")
  140. if __name__ == "__main__":
  141. parser = ArgumentParser()
  142. parser.add_argument("directory", type=str, help="Directory containing trajectory files")
  143. args = parser.parse_args()
  144. save_all_trajectories(args.directory)