run_replay.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. """Replay a trajectory"""
  2. from __future__ import annotations
  3. import json
  4. import os
  5. from argparse import ArgumentParser
  6. from pathlib import Path
  7. from typing import Any
  8. import yaml
  9. import run as runscript
  def _load_replay_actions(traj_path: str) -> list[str]:
      """Return the ``content`` of every assistant message in a trajectory file.

      A ``.yaml`` trajectory is expected to hold the message history list at its
      top level; any other file is parsed as JSON with the history stored under
      the ``"history"`` key.
      """
      with open(traj_path, encoding="utf-8") as f:
          if traj_path.endswith(".yaml"):
              history = yaml.safe_load(f)
          else:
              history = json.load(f)["history"]
      return [x["content"] for x in history if x["role"] == "assistant"]


  def _write_task_instances_file(instances: list[dict[str, Any]], instance_id: str) -> str:
      """Write the task instances matching *instance_id* to ``<instance_id>.jsonl``.

      The file is created in the current working directory with one JSON object
      per line; its (relative) path is returned.

      Raises:
          ValueError: If no entry of *instances* has a matching ``instance_id``.
              Nothing is written in that case.
      """
      matches = [d for d in instances if d["instance_id"] == instance_id]
      if not matches:
          raise ValueError(f"No task instance with instance_id {instance_id!r} found in the data file")
      tmp_path = instance_id + ".jsonl"
      with open(tmp_path, "w", encoding="utf-8") as f:
          f.writelines(json.dumps(d) + "\n" for d in matches)
      return tmp_path


  def process_single_traj(traj_path: str, config_file: str, data_path: str, suffix: str, *, forward_args: list[str]):
      """Replay the assistant actions recorded in one trajectory through ``run.py``.

      Args:
          traj_path (str): Trajectory file to replay (``.yaml`` or JSON). Its file
              name up to the first ``.`` is used as the task instance id.
          config_file (str): Passed to run.py as ``--config_file``.
          data_path (str): Source of task instances. If ``None``, it is read from
              ``environment.data_path`` in the ``args.yaml`` next to *traj_path*.
              ``.jsonl``/``.json`` files are filtered down to the replayed
              instance, written to a temporary ``<instance_id>.jsonl``; any other
              value (e.g. a GitHub URL) is passed to run.py unchanged.
          suffix (str): If not ``None``, passed to run.py as ``--suffix``.
          forward_args (List[str]): Passed to run.py verbatim.

      Raises:
          ValueError: If a ``.json``/``.jsonl`` data file contains no instance
              matching the trajectory's instance id.

      Returns:
          None
      """
      replay_action_trajs_path = "temp_replay.jsonl"

      # Extract the assistant responses so they can be replayed as actions.
      actions = _load_replay_actions(traj_path)
      # Path(...).name respects OS-specific separators; everything before the
      # first "." is the instance id.
      instance_id = Path(traj_path).name.split(".")[0]
      with open(replay_action_trajs_path, "w", encoding="utf-8") as f:
          f.write(json.dumps({instance_id: actions}) + "\n")

      # Every temporary file created here is removed in ``finally`` so a failure
      # (missing instance, bad data file, error inside run.py) leaves no stale files.
      temp_files = [replay_action_trajs_path]
      try:
          # Get data_path from the args.yaml stored alongside the trajectory
          if data_path is None:
              args_path = Path(traj_path).parent / "args.yaml"
              with open(args_path, encoding="utf-8") as f:
                  data_path = yaml.safe_load(f)["environment"]["data_path"]

          # Identify the relevant task instance and write it to its own file
          is_other = False
          if data_path.endswith(".jsonl"):
              lines = Path(data_path).read_text(encoding="utf-8").splitlines()
              # Skip blank / whitespace-only lines: they are not JSON records.
              instances = [json.loads(line) for line in lines if line.strip()]
              replay_task_instances_path = _write_task_instances_file(instances, instance_id)
              temp_files.append(replay_task_instances_path)
          elif data_path.endswith(".json"):
              with open(data_path, encoding="utf-8") as file:
                  instances = json.load(file)
              replay_task_instances_path = _write_task_instances_file(instances, instance_id)
              temp_files.append(replay_task_instances_path)
          else:
              # Assume data_path is a github url or local url; it is not a temp
              # file of ours, so it is never deleted.
              is_other = True
              replay_task_instances_path = data_path

          # Call run.py in-process
          run_args = [
              "--config_file",
              config_file,
              "--data_path",
              replay_task_instances_path,
              "--install_environment",
              "True",
              "--model_name",
              "replay",
              "--replay_path",
              replay_action_trajs_path,
              *forward_args,
          ]
          if is_other:
              # Not sure if this only applies to github urls for data_path
              run_args.extend(["--skip_existing", "False"])
          if suffix is not None:
              run_args.extend(["--suffix", suffix])
          script_args = runscript.get_args(run_args)
          runscript.main(script_args)
      finally:
          for path in temp_files:
              try:
                  os.remove(path)
              except FileNotFoundError:
                  # Already gone; nothing left to clean up.
                  pass
  def main(
      traj_path: str,
      config_file: str,
      data_path: str,
      suffix: str,
      *,
      forward_args: list[str],
  ):
      """Command-line entry point: replay a single trajectory.

      Every argument is handed unchanged to ``process_single_traj``; see its
      docstring for the meaning of each one.
      """
      process_single_traj(
          traj_path=traj_path,
          config_file=config_file,
          data_path=data_path,
          suffix=suffix,
          forward_args=forward_args,
      )
  def get_args(args=None):
      """Parse this script's own options, leaving unknown ones for run.py.

      Args:
          args: Argument list to parse; ``None`` means ``sys.argv[1:]``.

      Returns:
          tuple: ``(namespace, remaining)`` where *remaining* lists every
          argument this parser did not recognise.
      """
      parser = ArgumentParser(description=__doc__)
      # Registration order matches the original CLI (affects --help output).
      required_options = {
          "--traj_path": "Path to trajectory to replay",
          "--config_file": "Path to template",
      }
      optional_options = {
          "--data_path": "(Optional) Path to data file containing task instances ref'ed by replay trajectories",
          "--suffix": "(Optional) Suffix argument appended to end of traj path",
      }
      for flag, help_text in required_options.items():
          parser.add_argument(flag, help=help_text, required=True)
      for flag, help_text in optional_options.items():
          parser.add_argument(flag, help=help_text, default=None)
      return parser.parse_known_args(args=args)
  if __name__ == "__main__":
      parsed_args, passthrough_args = get_args()
      # Options this script does not recognise are forwarded to run.py.
      main(**vars(parsed_args), forward_args=passthrough_args)