run_replay.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import json
  2. import os
  3. import subprocess
  4. import yaml
  5. from argparse import ArgumentParser
  def process_synthetic_trajs(action_trajs_path: str, config_file: str, suffix: str):
      """Replay a JSONL file of synthetic action trajectories via ``run.py``.

      Every non-blank line of ``action_trajs_path`` must be a JSON object with
      a ``"task_instance"`` entry (which itself has an ``"instance_id"``) and
      an ``"actions"`` entry.  Two temporary files are written to the current
      working directory -- the task instances (named after the input file's
      base name) and one ``{instance_id: actions}`` mapping per line -- and
      ``python run.py`` is then invoked on them with ``--model_name replay``.
      Both temporary files are removed afterwards, even when an error occurs.

      Args:
          action_trajs_path: Path to the JSONL file of action trajectories.
          config_file: Forwarded to ``run.py`` as ``--config_file``.
          suffix: Forwarded to ``run.py`` as ``--suffix``; omitted when ``None``.

      Raises:
          ValueError: If a temporary file would overwrite the input file or
              the other temporary file.
      """
      # Temporary file names (both are created in the current working directory)
      replay_action_trajs_path = "temp_actions.jsonl"
      replay_task_instances_path = os.path.basename(action_trajs_path)

      # The task-instance file reuses the input's base name.  When the input
      # already lives in the working directory that would truncate it and then
      # delete it during cleanup, so refuse before touching anything.
      if os.path.realpath(replay_task_instances_path) == os.path.realpath(action_trajs_path):
          raise ValueError(
              "{} is in the current working directory; the temporary task "
              "instance file would overwrite it. Run from another directory "
              "or move the file.".format(action_trajs_path)
          )
      if replay_task_instances_path == replay_action_trajs_path:
          raise ValueError(
              "Input file name {} clashes with the temporary actions file; "
              "rename the input.".format(replay_task_instances_path)
          )

      # Load action trajectories, task instances (blank lines are skipped)
      with open(action_trajs_path, "r") as f:
          action_trajs = [json.loads(line) for line in f if line.strip()]
      task_instances = [traj["task_instance"] for traj in action_trajs]
      replay_entries = [
          {traj["task_instance"]["instance_id"]: traj["actions"]}
          for traj in action_trajs
      ]

      try:
          # Write task_instances to file for data_path
          with open(replay_task_instances_path, "w") as f:
              for instance in task_instances:
                  f.write(json.dumps(instance) + "\n")

          # Write action trajectories to a file
          with open(replay_action_trajs_path, "w") as f:
              for entry in replay_entries:
                  f.write(json.dumps(entry) + "\n")

          # Call run.py via subprocess
          command = [
              "python",
              "run.py",
              "--config_file", config_file,
              "--data_path", replay_task_instances_path,
              "--install_environment", "True",
              "--model_name", "replay",
              "--replay_path", replay_action_trajs_path,
          ]
          if suffix is not None:
              command.extend(["--suffix", suffix])
          # NOTE: the exit status of run.py is not checked.
          subprocess.run(command)
      finally:
          # Always clean up, including when writing or launching run.py fails.
          for temp_path in (replay_action_trajs_path, replay_task_instances_path):
              try:
                  os.remove(temp_path)
              except FileNotFoundError:
                  pass
  def process_single_traj(traj_path: str, config_file: str, data_path: str, suffix: str):
      """Replay one saved trajectory file via ``run.py``.

      The ``"content"`` of every ``"assistant"`` entry in the trajectory's
      ``"history"`` is collected as the list of actions to replay.  The
      instance id is the trajectory file's base name up to its first ``.``.
      The task instance with that ``"instance_id"`` is looked up in
      ``data_path``; two temporary files (``temp_replay.jsonl`` and
      ``<instance_id>.jsonl``) are written to the current working directory and
      ``python run.py`` is invoked on them with ``--model_name replay``.  The
      temporary files are removed afterwards, even when an error occurs.

      Args:
          traj_path: Path to the trajectory.  A ``.yaml`` file is loaded as the
              history itself; anything else is parsed as JSON with a
              ``"history"`` key.
          config_file: Forwarded to ``run.py`` as ``--config_file``.
          data_path: ``.json`` or ``.jsonl`` file of task instances.  When
              ``None`` it is read from ``environment.data_path`` in the
              ``args.yaml`` located next to ``traj_path``.
          suffix: Forwarded to ``run.py`` as ``--suffix``; omitted when ``None``.

      Raises:
          ValueError: If ``data_path`` is not ``.json``/``.jsonl``, if no task
              instance matches the instance id, or if a temporary file would
              overwrite ``traj_path``, ``data_path`` or the other temporary
              file.
      """
      replay_action_trajs_path = "temp_replay.jsonl"

      # Open trajectory file, extract responses as actions
      if traj_path.endswith(".yaml"):
          with open(traj_path, "r") as f:
              traj_data = {"history": yaml.safe_load(f)}
      else:
          with open(traj_path, "r") as f:
              traj_data = json.load(f)
      actions = [x["content"] for x in traj_data["history"] if x["role"] == "assistant"]

      instance_id = os.path.basename(traj_path).split(".")[0]
      replay_task_instances_path = instance_id + ".jsonl"

      # Get data_path from args.yaml
      if data_path is None:
          args_path = os.path.join(os.path.dirname(traj_path), "args.yaml")
          with open(args_path) as f:
              args = yaml.safe_load(f)
          data_path = args["environment"]["data_path"]

      # Both temporary files are created in the working directory and deleted
      # afterwards; refuse to run if that would destroy one of the inputs.
      if replay_task_instances_path == replay_action_trajs_path:
          raise ValueError(
              "Instance id {} clashes with the temporary replay file; "
              "rename the trajectory.".format(instance_id)
          )
      temp_targets = {
          os.path.realpath(replay_action_trajs_path),
          os.path.realpath(replay_task_instances_path),
      }
      for user_path in (traj_path, data_path):
          if os.path.realpath(user_path) in temp_targets:
              raise ValueError(
                  "{} would be overwritten by a temporary replay file. Run "
                  "from another directory or move the file.".format(user_path)
              )

      # Identify the relevant task instance (before any temp file is created,
      # so a bad data_path leaves nothing behind)
      if data_path.endswith(".jsonl"):
          with open(data_path, "r") as f:
              data = [json.loads(line) for line in f if line.strip()]
      elif data_path.endswith(".json"):
          with open(data_path) as f:
              data = json.load(f)
      else:
          raise ValueError("--data_path must be a .json or .jsonl")
      data = [d for d in data if d["instance_id"] == instance_id]
      if not data:
          raise ValueError(
              "No task instance with instance_id {!r} found in {}".format(
                  instance_id, data_path
              )
          )

      try:
          with open(replay_action_trajs_path, "w") as f:
              f.write(json.dumps({instance_id: actions}) + "\n")

          with open(replay_task_instances_path, "w") as f:
              for d in data:
                  f.write(json.dumps(d) + "\n")

          # Call run.py via subprocess
          command = [
              "python",
              "run.py",
              "--config_file", config_file,
              "--data_path", replay_task_instances_path,
              "--install_environment", "True",
              "--model_name", "replay",
              "--replay_path", replay_action_trajs_path,
          ]
          if suffix is not None:
              command.extend(["--suffix", suffix])
          # NOTE: the exit status of run.py is not checked.
          subprocess.run(command)
      finally:
          # Always clean up, including when writing or launching run.py fails.
          for temp_path in (replay_action_trajs_path, replay_task_instances_path):
              try:
                  os.remove(temp_path)
              except FileNotFoundError:
                  pass
  def main(
      action_trajs_path: str,
      traj_path: str,
      config_file: str,
      data_path: str,
      suffix: str,
  ):
      """Choose the replay mode from the arguments and run it.

      ``action_trajs_path`` takes precedence: when it is not ``None`` the
      synthetic trajectories are replayed.  Otherwise, when ``traj_path`` is
      not ``None``, that single trajectory is replayed (``data_path`` is only
      used in this mode).  If both are ``None`` a usage hint is printed and
      nothing is replayed.
      """
      if action_trajs_path is not None:
          process_synthetic_trajs(action_trajs_path, config_file, suffix)
          return

      if traj_path is not None:
          process_single_traj(traj_path, config_file, data_path, suffix)
          return

      # Neither input was supplied: explain what is expected.
      usage_hint = (
          "No replays generated.\n"
          "You must either provide one of the following. Either...\n"
          "\t* --action_trajs_path for replaying synthetic trajectories\n"
          "\t* --traj_path for replaying SWE-agent style trajectories (from ./trajectories folder)\n"
      )
      print(usage_hint)
  if __name__ == "__main__":
      # Command-line entry point: every option maps one-to-one onto a
      # parameter of main(), so the parsed namespace is passed as keywords.
      cli = ArgumentParser()
      cli.add_argument(
          "--action_trajs_path",
          default=None,
          help="Path to action trajectories to replay",
      )
      cli.add_argument(
          "--traj_path",
          default=None,
          help="Path to trajectory to replay",
      )
      cli.add_argument(
          "--config_file",
          required=True,
          help="Path to template",
      )
      cli.add_argument(
          "--data_path",
          default=None,
          help="(Optional) Path to data file containing task instances ref'ed by replay trajectories",
      )
      cli.add_argument(
          "--suffix",
          default=None,
          help="(Optional) Suffix argument appended to end of traj path",
      )
      parsed = cli.parse_args()
      main(**vars(parsed))