static.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import json
  2. import logging
  3. import yaml
  4. from pathlib import Path
  5. from tqdm.auto import tqdm
  6. from argparse import ArgumentParser
  7. try:
  8. from .server import load_content
  9. except ImportError:
  10. from server import load_content
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Pin the "simple_parsing" logger to INFO level.
# NOTE(review): presumably this hides that library's debug output - confirm.
logging.getLogger("simple_parsing").setLevel(logging.INFO)
# HTML skeleton of the static viewer page. It is filled in with
# ``TEMPLATE.format(...)`` using three placeholders:
#   {style_sheet}    - CSS text placed inside the <style> element
#   {file_path_tree} - HTML fragment shown at the top of the container
#   {file_content}   - HTML fragment placed inside the <pre> element
# Because ``str.format`` is used, any literal brace added to this template
# would have to be doubled ("{{" / "}}").
TEMPLATE = """
<html>
<head>
<title>Trajectory Viewer</title>
<style>
{style_sheet}
</style>
</head>
<body>
<div class="container">
{file_path_tree}
<h2>Conversation History</h2>
<pre id="fileContent">{file_content}</pre>
</div>
</body>
</html>
"""
  30. try:
  31. with open(Path(__file__).parent / 'style.css', 'r') as infile:
  32. STYLE_SHEET = infile.read()
  33. except Exception as e:
  34. style_file = Path(__file__).parent / 'style.css'
  35. logger.error(f"Failed to load style sheet from {style_file}: {e}")
  36. raise e
  37. def _load_file(file_name, gold_patches, test_patches):
  38. try:
  39. role_map = {
  40. "user": "Computer",
  41. "assistant": "SWE-Agent",
  42. "subroutine": "SWE-Agent subroutine",
  43. "default": "Default",
  44. "system": "System",
  45. }
  46. content = load_content(file_name, gold_patches, test_patches)
  47. if 'history' in content and isinstance(content['history'], list):
  48. history_content = ""
  49. for index, item in enumerate(content['history']):
  50. item_content = item.get('content', '').replace('<', '&lt;').replace('>', '&gt;')
  51. if item.get('agent') and item['agent'] != "primary":
  52. role_class = "subroutine"
  53. else:
  54. role_class = item.get('role', 'default').lower().replace(' ', '-')
  55. element_id = f"historyItem{index}"
  56. role_name = role_map.get(item.get('role', ''), item.get('role', ''))
  57. history_content += (
  58. f'''<div class="history-item {role_class}" id="{element_id}">'''
  59. f'''<div class="role-bar {role_class}"><strong><span>{role_name}</span></strong></div>'''
  60. f'''<div class="content-container">'''
  61. f'''<pre>{item_content}</pre>'''
  62. f'''</div>'''
  63. f'''<div class="shadow"></div>'''
  64. f'''</div>'''
  65. )
  66. return history_content
  67. else:
  68. return 'No history content found.'
  69. except Exception as e:
  70. return f"Error loading content. {e}"
  71. def _make_file_path_tree(file_path):
  72. path_parts = file_path.split('/')
  73. relevant_parts = path_parts[-3:]
  74. html_string = '<div class="filepath">\n'
  75. for part in relevant_parts:
  76. html_string += f'<div class="part">{part}</div>\n'
  77. html_string += '</div>'
  78. return html_string
  79. def save_static_viewer(file_path):
  80. if not isinstance(file_path, Path):
  81. file_path = Path(file_path)
  82. data = []
  83. if "args.yaml" in list(map(lambda x: x.name, file_path.parent.iterdir())):
  84. args = yaml.safe_load(open(file_path.parent / "args.yaml", "r"))
  85. if "environment" in args and "data_path" in args["environment"]:
  86. data_path = Path(__file__).parent.parent / args["environment"]["data_path"]
  87. if data_path.exists():
  88. data = json.load(open(data_path, "r"))
  89. if (
  90. not isinstance(data, list) or
  91. not data or
  92. 'patch' not in data[0] or
  93. 'test_patch' not in data[0]
  94. ):
  95. data = []
  96. gold_patches = {x["instance_id"]: x["patch"] for x in data}
  97. test_patches = {x["instance_id"]: x["test_patch"] for x in data}
  98. content = _load_file(file_path, gold_patches, test_patches)
  99. file_path_tree = _make_file_path_tree(file_path.absolute().as_posix())
  100. icons_path = Path(__file__).parent / 'icons'
  101. relative_icons_path = find_relative_path(file_path, icons_path)
  102. style_sheet = STYLE_SHEET.replace(
  103. "url('icons/",
  104. f"url('{relative_icons_path.as_posix()}/"
  105. ).replace(
  106. 'url("icons/',
  107. f'url("{relative_icons_path.as_posix()}/'
  108. )
  109. data = TEMPLATE.format(file_content=content, style_sheet=style_sheet, file_path_tree=file_path_tree)
  110. output_file = file_path.with_suffix('.html')
  111. with open(output_file, 'w+') as outfile:
  112. print(data, file=outfile)
  113. logger.info(f"Saved static viewer to {output_file}")
  114. def find_relative_path(from_path, to_path):
  115. # Convert paths to absolute for uniformity
  116. from_path = from_path.resolve()
  117. to_path = to_path.resolve()
  118. if from_path.is_file():
  119. from_path = from_path.parent
  120. if to_path.is_file():
  121. to_path = to_path.parent
  122. if not from_path.is_dir() or not to_path.is_dir():
  123. raise ValueError(f"Both from_path and to_path must be directories, but got {from_path} and {to_path}")
  124. # Identify the common ancestor and the parts of each path beyond it
  125. common_parts = 0
  126. for from_part, to_part in zip(from_path.parts, to_path.parts):
  127. if from_part != to_part:
  128. break
  129. common_parts += 1
  130. # Calculate the '../' needed to get back from from_path to the common ancestor
  131. back_to_ancestor = ['..'] * (len(from_path.parts) - common_parts)
  132. # Direct path from common ancestor to to_path
  133. to_target = to_path.parts[common_parts:]
  134. # Combine to get the relative path
  135. relative_path = Path(*back_to_ancestor, *to_target)
  136. return relative_path
  137. def save_all_trajectories(directory):
  138. if not isinstance(directory, Path):
  139. directory = Path(directory)
  140. all_files = list(directory.glob('**/*.traj'))
  141. logger.info(f"Found {len(all_files)} trajectory files in {directory}")
  142. for file_path in tqdm(all_files, desc="Saving static viewers"):
  143. save_static_viewer(file_path)
  144. logger.info(f"Saved static viewers for all trajectories in {args.directory}")
# Command-line entry point: takes one positional argument, the directory to
# search for trajectory files, and exports a static viewer for each of them.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("directory", type=str, help="Directory containing trajectory files")
    # NOTE: ``args`` is bound at module level here (this block is not inside a
    # function), so it is visible as a global when the file runs as a script.
    args = parser.parse_args()
    save_all_trajectories(args.directory)