convert_traj_to_demo.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from __future__ import annotations
  2. import io
  3. import json
  4. from argparse import ArgumentParser
  5. from pathlib import Path
  6. from ruamel.yaml import YAML
  7. from ruamel.yaml.scalarstring import LiteralScalarString as LSS
# Header comment template written at the top of every generated demo file.
# The ``{traj_path}`` placeholder is filled in with ``str.format`` by
# ``save_demo`` before the YAML content is appended.
DEMO_COMMENT = """# This is a demo file generated from trajectory file:
# {traj_path}
# You can use this demo file to replay the actions in the trajectory with run_replay.py.
# You can edit the content of the actions in this file to modify the replay behavior.
# NOTICE:
# Only the actions of the assistant will be replayed.
# You do not need to modify the observation's contents or any other fields.
# You can add or remove actions to modify the replay behavior."""
  16. def convert_to_literal_string(d):
  17. """
  18. Convert any multi-line strings to LiteralScalarString
  19. """
  20. if isinstance(d, dict):
  21. for key, value in d.items():
  22. if isinstance(value, str) and "\n" in value:
  23. d[key] = LSS(value.replace("\r\n", "\n").replace("\r", "\n"))
  24. elif isinstance(value, dict):
  25. convert_to_literal_string(value)
  26. elif isinstance(d, list):
  27. for i, item in enumerate(d):
  28. if isinstance(item, str) and "\n" in item:
  29. d[i] = LSS(item.replace("\r\n", "\n").replace("\r", "\n"))
  30. elif isinstance(item, dict):
  31. convert_to_literal_string(item)
  32. elif isinstance(d, str) and "\n" in d:
  33. d = LSS(d.replace("\r\n", "\n").replace("\r", "\n"))
  34. else:
  35. raise ValueError(f"Unsupported type: {type(d)}")
  36. return d
  37. def save_demo(data, file, traj_path):
  38. """
  39. Save a single task instance as a yaml file
  40. """
  41. data = convert_to_literal_string(data)
  42. yaml = YAML()
  43. yaml.indent(mapping=2, sequence=4, offset=2)
  44. buffer = io.StringIO()
  45. yaml.dump(data, buffer)
  46. content = buffer.getvalue()
  47. header = DEMO_COMMENT.format(traj_path=traj_path)
  48. with open(file, "w") as f:
  49. f.write(f"{header}\n{content}")
  50. def convert_traj_to_action_demo(traj_path: str, output_file: str = None, include_user: bool = False):
  51. traj = json.load(open(traj_path))
  52. history = traj["history"]
  53. action_traj = list()
  54. admissible_roles = {"assistant", "user"} if include_user else {"assistant"}
  55. for step in history:
  56. if step["role"] in admissible_roles and step.get("agent", "primary") == "primary":
  57. action_traj.append({k: v for k, v in step.items() if k in {"content", "role"}})
  58. save_demo(action_traj, output_file, traj_path)
  59. print(f"Saved demo to {output_file}")
  60. def main(traj_path: str, output_dir: str = None, suffix: str = "", overwrite: bool = False, include_user: bool = False):
  61. filename = (
  62. "/".join([Path(traj_path).parent.name + suffix, Path(traj_path).name.rsplit(".traj", 1)[0]]) + ".demo.yaml"
  63. )
  64. output_file = Path(output_dir) / filename
  65. if output_file.exists() and not overwrite:
  66. raise FileExistsError(f"Output file already exists: {output_file}")
  67. output_file.parent.mkdir(parents=True, exist_ok=True)
  68. convert_traj_to_action_demo(traj_path, output_file, include_user)
  69. def string2bool(s):
  70. if s.lower() in {"true", "1"}:
  71. return True
  72. elif s.lower() in {"false", "0"}:
  73. return False
  74. else:
  75. raise ValueError(f"Invalid boolean string: {s}")
  76. if __name__ == "__main__":
  77. parser = ArgumentParser()
  78. parser.add_argument("traj_path", type=str, help="Path to trajectory file")
  79. parser.add_argument("--output_dir", type=str, help="Output directory for action demos", default="./demos")
  80. parser.add_argument("--suffix", type=str, help="Suffix for the output file", default="")
  81. parser.add_argument("--overwrite", type=string2bool, help="Overwrite existing files", default=False, nargs="?")
  82. parser.add_argument(
  83. "--include_user", type=string2bool, help="Include user responses (computer)", default=False, nargs="?"
  84. )
  85. args = parser.parse_args()
  86. main(**vars(args))