convert_traj_to_demo.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import json
  2. import io
  3. from ruamel.yaml import YAML
  4. from ruamel.yaml.scalarstring import LiteralScalarString as LSS
  5. from pathlib import Path
  6. from argparse import ArgumentParser
# Header written at the top of every generated demo file by save_demo, after
# ``{traj_path}`` is filled in with str.format. Every line starts with "#", so the
# header is a block of YAML comments placed before the dumped demo content.
DEMO_COMMENT = """# This is a demo file generated from trajectory file:
# {traj_path}
# You can use this demo file to replay the actions in the trajectory with run_replay.py.
# You can edit the content of the actions in this file to modify the replay behavior.
# NOTICE:
# Only the actions of the assistant will be replayed.
# You do not need to modify the observation's contents or any other fields.
# You can add or remove actions to modify the replay behavior."""
  15. def convert_to_literal_string(d):
  16. """
  17. Convert any multi-line strings to LiteralScalarString
  18. """
  19. if isinstance(d, dict):
  20. for key, value in d.items():
  21. if isinstance(value, str) and '\n' in value:
  22. d[key] = LSS(value.replace('\r\n', '\n').replace('\r', '\n'))
  23. elif isinstance(value, dict):
  24. convert_to_literal_string(value)
  25. elif isinstance(d, list):
  26. for i, item in enumerate(d):
  27. if isinstance(item, str) and '\n' in item:
  28. d[i] = LSS(item.replace('\r\n', '\n').replace('\r', '\n'))
  29. elif isinstance(item, dict):
  30. convert_to_literal_string(item)
  31. elif isinstance(d, str) and '\n' in d:
  32. d = LSS(d.replace('\r\n', '\n').replace('\r', '\n'))
  33. else:
  34. raise ValueError(f"Unsupported type: {type(d)}")
  35. return d
  36. def save_demo(data, file, traj_path):
  37. """
  38. Save a single task instance as a yaml file
  39. """
  40. data = convert_to_literal_string(data)
  41. yaml = YAML()
  42. yaml.indent(mapping=2, sequence=4, offset=2)
  43. buffer = io.StringIO()
  44. yaml.dump(data, buffer)
  45. content = buffer.getvalue()
  46. header = DEMO_COMMENT.format(traj_path=traj_path)
  47. with open(file, "w") as f:
  48. f.write(f"{header}\n{content}")
  49. def convert_traj_to_action_demo(traj_path: str, output_file: str = None, include_user: bool = False):
  50. traj = json.load(open(traj_path))
  51. history = traj["history"]
  52. action_traj = list()
  53. admissable_roles = {"assistant", "user"} if include_user else {"assistant"}
  54. for step in history:
  55. if step['role'] in admissable_roles and step.get('agent', 'primary') == 'primary':
  56. action_traj.append({k: v for k, v in step.items() if k in {'content', 'role'}})
  57. save_demo(action_traj, output_file, traj_path)
  58. print(f"Saved demo to {output_file}")
  59. def main(traj_path: str, output_dir: str = None, suffix: str = "", overwrite: bool = False, include_user: bool = False):
  60. filename = '/'.join([Path(traj_path).parent.name + suffix, Path(traj_path).name.rsplit('.traj', 1)[0]]) + ".demo.yaml"
  61. output_file = Path(output_dir) / filename
  62. if output_file.exists() and not overwrite:
  63. raise FileExistsError(f"Output file already exists: {output_file}")
  64. output_file.parent.mkdir(parents=True, exist_ok=True)
  65. convert_traj_to_action_demo(traj_path, output_file, include_user)
  66. def string2bool(s):
  67. if s.lower() in {"true", "1"}:
  68. return True
  69. elif s.lower() in {"false", "0"}:
  70. return False
  71. else:
  72. raise ValueError(f"Invalid boolean string: {s}")
  73. if __name__ == "__main__":
  74. parser = ArgumentParser()
  75. parser.add_argument("traj_path", type=str, help="Path to trajectory file")
  76. parser.add_argument("--output_dir", type=str, help="Output directory for action demos", default="./demos")
  77. parser.add_argument("--suffix", type=str, help="Suffix for the output file", default="")
  78. parser.add_argument("--overwrite", type=string2bool, help="Overwrite existing files", default=False, nargs='?')
  79. parser.add_argument("--include_user", type=string2bool, help="Include user responses (computer)", default=False, nargs='?')
  80. args = parser.parse_args()
  81. main(**vars(args))