run_replay.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. """Replay a trajectory"""
  2. import json
  3. import os
  4. import yaml
  5. from argparse import ArgumentParser
  6. from typing import Any, Dict, List
  7. import run as runscript
  8. def process_single_traj(
  9. traj_path: str,
  10. config_file: str,
  11. data_path: str,
  12. suffix: str,
  13. *,
  14. forward_args: List[str],
  15. ):
  16. """
  17. Args:
  18. traj_path (str): _description_
  19. config_file (str): _description_
  20. data_path (str): _description_
  21. suffix (str): _description_
  22. forward_args (List[str]): Passed to run.py
  23. Raises:
  24. ValueError: Incorrect paths or other config issue
  25. Returns:
  26. None
  27. """
  28. replay_action_trajs_path = "temp_replay.jsonl"
  29. # Open trajectory file, extract responses as actions
  30. if traj_path.endswith(".yaml"):
  31. traj_data = dict()
  32. with open(traj_path, "r") as f:
  33. traj_data["history"] = yaml.safe_load(f)
  34. else:
  35. traj_data = json.load(open(traj_path, "r"))
  36. actions = [x["content"] for x in traj_data["history"] if x["role"] == "assistant"]
  37. instance_id = traj_path.split("/")[-1].split(".")[0]
  38. with open(replay_action_trajs_path, "w") as f:
  39. print(json.dumps({instance_id: actions}), file=f, end="\n", flush=True)
  40. # Get data_path from args.yaml
  41. if data_path is None:
  42. args_path = os.path.join(os.path.dirname(traj_path), "args.yaml")
  43. args = yaml.safe_load(open(args_path))
  44. data_path = args["environment"]["data_path"]
  45. # Identify the relevant task instance and create it
  46. def create_task_instances_tmp_file(data: List[Dict[str, Any]]) -> str:
  47. """Helper function to create a temporary file to write task instances to.
  48. Returns path to the temporary file.
  49. """
  50. data = [d for d in data if d["instance_id"] == instance_id]
  51. tmp_path = instance_id + ".jsonl"
  52. with open(tmp_path, "w") as f:
  53. for d in data:
  54. print(json.dumps(d), file=f, end="\n", flush=True)
  55. return tmp_path
  56. is_other = False
  57. if data_path.endswith(".jsonl"):
  58. replay_task_instances_path = create_task_instances_tmp_file(
  59. [json.loads(x) for x in open(data_path, "r").readlines()]
  60. )
  61. elif data_path.endswith(".json"):
  62. replay_task_instances_path = create_task_instances_tmp_file(json.load(open(data_path)))
  63. else:
  64. # Assume data_path is a github url or local url
  65. is_other = True
  66. replay_task_instances_path = data_path
  67. # Call run.py via subprocess
  68. run_args = [
  69. "--config_file",
  70. config_file,
  71. "--data_path",
  72. replay_task_instances_path,
  73. "--install_environment",
  74. "True",
  75. "--model_name",
  76. "replay",
  77. "--replay_path",
  78. replay_action_trajs_path,
  79. *forward_args,
  80. ]
  81. if is_other:
  82. # Not sure if this only applies to github urls for data_path
  83. run_args.extend(["--skip_existing", "False"])
  84. if suffix is not None:
  85. run_args.extend(["--suffix", suffix])
  86. script_args = runscript.get_args(run_args)
  87. runscript.main(script_args)
  88. os.remove(replay_action_trajs_path)
  89. if not is_other:
  90. os.remove(replay_task_instances_path)
  91. def main(
  92. traj_path: str,
  93. config_file: str,
  94. data_path: str,
  95. suffix: str,
  96. *,
  97. forward_args: List[str],
  98. ):
  99. process_single_traj(traj_path, config_file, data_path, suffix, forward_args=forward_args)
  100. def get_args(args=None):
  101. parser = ArgumentParser(description=__doc__)
  102. parser.add_argument("--traj_path", help="Path to trajectory to replay", default=None)
  103. parser.add_argument("--config_file", help="Path to template", required=True)
  104. parser.add_argument(
  105. "--data_path",
  106. help="(Optional) Path to data file containing task instances ref'ed by replay trajectories",
  107. default=None,
  108. )
  109. parser.add_argument(
  110. "--suffix",
  111. help="(Optional) Suffix argument appended to end of traj path",
  112. default=None,
  113. )
  114. args, remaining_args = parser.parse_known_args(args=args)
  115. return args, remaining_args
  116. if __name__ == "__main__":
  117. args, remaining_args = get_args()
  118. main(**vars(args), forward_args=remaining_args)