123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- from __future__ import annotations
- import io
- import json
- from argparse import ArgumentParser
- from pathlib import Path
- from ruamel.yaml import YAML
- from ruamel.yaml.scalarstring import LiteralScalarString as LSS
# Comment header placed at the top of a generated demo file.
# The `{traj_path}` placeholder is filled in via str.format() by save_demo.
DEMO_COMMENT = """# This is a demo file generated from trajectory file:
# {traj_path}
# You can use this demo file to replay the actions in the trajectory with run_replay.py.
# You can edit the content of the actions in this file to modify the replay behavior.
# NOTICE:
# Only the actions of the assistant will be replayed.
# You do not need to modify the observation's contents or any other fields.
# You can add or remove actions to modify the replay behavior."""
def convert_to_literal_string(d):
    """
    Convert any multi-line strings to LiteralScalarString

    Walks ``d`` recursively through nested dicts and lists. Containers are
    updated in place and the same container object is returned. A string
    containing a newline is returned as a new ``LiteralScalarString`` with
    CRLF and lone CR line endings normalised to LF; a single-line string is
    returned unchanged. Values of any other type nested inside a container
    are left untouched.

    Raises:
        ValueError: if ``d`` itself is not a dict, list, or str.
    """
    if isinstance(d, str):
        # Single-line strings need no conversion; hand them back as-is
        # instead of rejecting them as an unsupported type.
        if "\n" not in d:
            return d
        return LSS(d.replace("\r\n", "\n").replace("\r", "\n"))
    if isinstance(d, dict):
        # Only existing keys are reassigned, so the dict's size never changes
        # while it is being iterated.
        for key, value in d.items():
            if isinstance(value, (str, dict, list)):
                d[key] = convert_to_literal_string(value)
        return d
    if isinstance(d, list):
        for i, item in enumerate(d):
            if isinstance(item, (str, dict, list)):
                d[i] = convert_to_literal_string(item)
        return d
    raise ValueError(f"Unsupported type: {type(d)}")
def save_demo(data, file, traj_path):
    """
    Save a single task instance as a yaml file

    Args:
        data: Demo content; it is passed through ``convert_to_literal_string``
            before being dumped.
        file: Destination path of the yaml file (truncated if it exists).
        traj_path: Path of the source trajectory; only interpolated into the
            comment header written at the top of the file.
    """
    data = convert_to_literal_string(data)
    yaml = YAML()
    yaml.indent(mapping=2, sequence=4, offset=2)
    # Dump into memory first so the comment header can precede the YAML body.
    buffer = io.StringIO()
    yaml.dump(data, buffer)
    content = buffer.getvalue()
    header = DEMO_COMMENT.format(traj_path=traj_path)
    # Write as UTF-8 explicitly: the platform default encoding may not be able
    # to represent non-ASCII characters present in the demo content.
    with open(file, "w", encoding="utf-8") as f:
        f.write(f"{header}\n{content}")
def convert_traj_to_action_demo(traj_path: str, output_file: str = None, include_user: bool = False):
    """
    Extract the replayable steps of a trajectory file and save them as a demo

    Args:
        traj_path: Path to a JSON trajectory file with a top-level
            ``"history"`` list of step dicts.
        output_file: Path the demo yaml is written to (forwarded to
            ``save_demo``).
        include_user: When true, steps with role ``"user"`` are kept in
            addition to ``"assistant"`` steps.
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(traj_path, encoding="utf-8") as f:
        traj = json.load(f)
    admissible_roles = {"assistant", "user"} if include_user else {"assistant"}
    # Keep only steps of an admissible role whose "agent" is "primary" (a
    # missing "agent" key counts as "primary"), reduced to content and role.
    action_traj = [
        {k: v for k, v in step.items() if k in {"content", "role"}}
        for step in traj["history"]
        if step["role"] in admissible_roles and step.get("agent", "primary") == "primary"
    ]
    save_demo(action_traj, output_file, traj_path)
    print(f"Saved demo to {output_file}")
def main(traj_path: str, output_dir: str = None, suffix: str = "", overwrite: bool = False, include_user: bool = False):
    """
    Convert one trajectory file into a demo yaml file under ``output_dir``

    The demo is written to
    ``<output_dir>/<trajectory parent dir name><suffix>/<trajectory name>.demo.yaml``
    where the trajectory name is the file name cut at its last ``.traj``.

    Args:
        traj_path: Path to the trajectory file.
        output_dir: Directory the demo is placed under; ``None`` falls back
            to ``./demos``.
        suffix: Text appended to the trajectory's parent directory name.
        overwrite: Allow replacing an already existing demo file.
        include_user: Forwarded to ``convert_traj_to_action_demo``.

    Raises:
        FileExistsError: if the demo file exists and ``overwrite`` is false.
    """
    traj = Path(traj_path)
    # Path(None) raises TypeError, so mirror the command-line default instead.
    base_dir = Path("./demos" if output_dir is None else output_dir)
    demo_name = traj.name.rsplit(".traj", 1)[0] + ".demo.yaml"
    # Join with pathlib rather than "/".join: when the trajectory has no parent
    # directory name (e.g. "x.traj") the string join produced "/x.demo.yaml",
    # an absolute path that discarded output_dir entirely. pathlib ignores the
    # empty segment and keeps the file under base_dir.
    output_file = base_dir / (traj.parent.name + suffix) / demo_name
    if output_file.exists() and not overwrite:
        raise FileExistsError(f"Output file already exists: {output_file}")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    convert_traj_to_action_demo(traj_path, output_file, include_user)
def string2bool(s):
    """
    Parse a textual boolean, ignoring case

    ``"true"`` / ``"1"`` map to True and ``"false"`` / ``"0"`` map to False.

    Raises:
        ValueError: for any other string.
    """
    known = {"true": True, "1": True, "false": False, "0": False}
    parsed = known.get(s.lower())
    if parsed is None:
        raise ValueError(f"Invalid boolean string: {s}")
    return parsed
if __name__ == "__main__":
    # Command-line entry point: parsed option names match main()'s parameters,
    # so the namespace is forwarded as keyword arguments.
    parser = ArgumentParser()
    parser.add_argument("traj_path", type=str, help="Path to trajectory file")
    parser.add_argument("--output_dir", type=str, help="Output directory for action demos", default="./demos")
    parser.add_argument("--suffix", type=str, help="Suffix for the output file", default="")
    # nargs="?" makes the value optional. Without const, a bare `--overwrite`
    # parses to None (falsy) and silently leaves the option off; const=True
    # makes the bare flag enable it while `--overwrite true/false` still works.
    parser.add_argument(
        "--overwrite", type=string2bool, help="Overwrite existing files", default=False, nargs="?", const=True
    )
    parser.add_argument(
        "--include_user",
        type=string2bool,
        help="Include user responses (computer)",
        default=False,
        nargs="?",
        const=True,
    )
    args = parser.parse_args()
    main(**vars(args))
|