commands.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. import re
  2. import yaml
  3. from abc import abstractmethod
  4. from dataclasses import dataclass
  5. from pathlib import Path
  6. from typing import Dict, List, Optional
  7. from simple_parsing.helpers import FrozenSerializable
@dataclass(frozen=True)
class AssistantMetadata(FrozenSerializable):
    """Pass observations to the assistant, and get back a response.

    Frozen (immutable) dataclass holding two optional template strings.
    Both fields default to ``None``.
    """

    # Optional system template string; None when not provided.
    system_template: Optional[str] = None
    # Optional instance template string; None when not provided.
    instance_template: Optional[str] = None
# TODO: first can be used for two-stage actions
# TODO: eventually might control high-level control flow
@dataclass(frozen=True)
class ControlMetadata(FrozenSerializable):
    """Frozen container for the optional "next step" template strings.

    TODO: should be able to control high-level control flow after calling
    this command.
    """

    # Optional template string for the next step; None when not provided.
    next_step_template: Optional[str] = None
    # Optional template string for the next step's action; None when not provided.
    next_step_action_template: Optional[str] = None
@dataclass(frozen=True)
class Command(FrozenSerializable):
    """Frozen description of a single command produced by a command parser.

    Only ``code`` and ``name`` are required; every other field defaults to
    ``None``.
    """

    # Source text of the command (a bash function's text, or a whole script file).
    code: str
    # Name of the command.
    name: str
    # Human-readable description; None when the command is undocumented.
    docstring: Optional[str] = None
    end_name: Optional[str] = None  # if there is an end_name, then it is a multi-line command
    # Mapping of parameter name -> settings dict; the parsers in this file read
    # a "required" flag from each settings dict.
    arguments: Optional[Dict] = None
    # Usage string, e.g. "name <required_arg> [<optional_arg>]".
    signature: Optional[str] = None
  28. class ParseCommandMeta(type):
  29. _registry = {}
  30. def __new__(cls, name, bases, attrs):
  31. new_cls = super().__new__(cls, name, bases, attrs)
  32. if name != "ParseCommand":
  33. cls._registry[name] = new_cls
  34. return new_cls
  35. @dataclass
  36. class ParseCommand(metaclass=ParseCommandMeta):
  37. @classmethod
  38. def get(cls, name):
  39. try:
  40. return cls._registry[name]()
  41. except KeyError:
  42. raise ValueError(f"Command parser ({name}) not found.")
  43. @abstractmethod
  44. def parse_command_file(self, path: str) -> List[Command]:
  45. """
  46. Define how to parse a file into a list of commands.
  47. """
  48. raise NotImplementedError
  49. @abstractmethod
  50. def generate_command_docs(self, commands: List[Command], subroutine_types, **kwargs) -> str:
  51. """
  52. Generate a string of documentation for the given commands and subroutine types.
  53. """
  54. raise NotImplementedError
# DEFINE NEW COMMAND PARSER CLASSES (ParseCommand SUBCLASSES) BELOW THIS LINE
  56. class ParseCommandBash(ParseCommand):
  57. def parse_command_file(self, path: str) -> List[Command]:
  58. print('Parsing command file:', path)
  59. contents = open(path, "r").read()
  60. if contents.strip().startswith("#!"):
  61. commands = self.parse_script(path, contents)
  62. else:
  63. if not path.endswith(".sh") and not Path(path).name.startswith("_"):
  64. raise ValueError((
  65. f"Source file {path} does not have a .sh extension.\n"
  66. "Only .sh files are supported for bash function parsing.\n"
  67. "If you want to use a non-shell file as a command (script), "
  68. "it should use a shebang (e.g. #!/usr/bin/env python)."
  69. ))
  70. return self.parse_bash_functions(path, contents)
  71. if len(commands) == 0 and not Path(path).name.startswith("_"):
  72. raise ValueError((
  73. f"Non-shell file {path} does not contain any commands.\n"
  74. "If you want to use a non-shell file as a command (script), "
  75. "it should contain exactly one @yaml docstring. "
  76. "If you want to use a file as a utility script, "
  77. "it should start with an underscore (e.g. _utils.py)."
  78. ))
  79. else:
  80. return commands
  81. def parse_bash_functions(self, path, contents) -> List[Command]:
  82. """
  83. Simple logic for parsing a bash file and segmenting it into functions.
  84. Assumes that all functions have their name and opening curly bracket in one line,
  85. and closing curly bracket in a line by itself.
  86. """
  87. lines = contents.split("\n")
  88. commands = []
  89. idx = 0
  90. docs = []
  91. while idx < len(lines):
  92. line = lines[idx]
  93. idx += 1
  94. if line.startswith("# "):
  95. docs.append(line[2:])
  96. elif line.strip().endswith("() {"):
  97. name = line.split()[0][:-2]
  98. code = line
  99. while lines[idx].strip() != "}":
  100. code += lines[idx]
  101. idx += 1
  102. code += lines[idx]
  103. docstring, end_name, arguments, signature = None, None, None, name
  104. docs_dict = yaml.safe_load("\n".join(docs).replace('@yaml', ''))
  105. if docs_dict is not None:
  106. docstring = docs_dict["docstring"]
  107. end_name = docs_dict.get("end_name", None)
  108. arguments = docs_dict.get("arguments", None)
  109. if "signature" in docs_dict:
  110. signature = docs_dict["signature"]
  111. else:
  112. if arguments is not None:
  113. for param, settings in arguments.items():
  114. if settings["required"]:
  115. signature += f" <{param}>"
  116. else:
  117. signature += f" [<{param}>]"
  118. command = Command.from_dict({
  119. "code": code,
  120. "docstring": docstring,
  121. "end_name": end_name,
  122. "name": name,
  123. "arguments": arguments,
  124. "signature": signature
  125. })
  126. commands.append(command)
  127. docs = []
  128. return commands
  129. def parse_script(self, path, contents) -> List[Command]:
  130. pattern = re.compile(r'^#\s*@yaml\s*\n^#.*(?:\n#.*)*', re.MULTILINE)
  131. matches = pattern.findall(contents)
  132. if len(matches) == 0:
  133. return []
  134. elif len(matches) > 1:
  135. raise ValueError((
  136. "Non-shell file contains multiple @yaml tags.\n"
  137. "Only one @yaml tag is allowed per script."
  138. ))
  139. else:
  140. yaml_content = matches[0]
  141. yaml_content = re.sub(r'^#', '', yaml_content, flags=re.MULTILINE)
  142. docs_dict = yaml.safe_load(yaml_content.replace('@yaml', ''))
  143. assert docs_dict is not None
  144. docstring = docs_dict["docstring"]
  145. end_name = docs_dict.get("end_name", None)
  146. arguments = docs_dict.get("arguments", None)
  147. signature = docs_dict.get("signature", None)
  148. name = Path(path).name.rsplit(".", 1)[0]
  149. if signature is None and arguments is not None:
  150. signature = name
  151. for param, settings in arguments.items():
  152. if settings["required"]:
  153. signature += f" <{param}>"
  154. else:
  155. signature += f" [<{param}>]"
  156. code = contents
  157. return [Command.from_dict({
  158. "code": code,
  159. "docstring": docstring,
  160. "end_name": end_name,
  161. "name": name,
  162. "arguments": arguments,
  163. "signature": signature
  164. })]
  165. def generate_command_docs(self, commands: List[Command], subroutine_types, **kwargs) -> str:
  166. docs = ""
  167. for cmd in commands:
  168. if cmd.docstring is not None:
  169. docs += f"{cmd.signature or cmd.name} - {cmd.docstring.format(**kwargs)}\n"
  170. for subroutine in subroutine_types:
  171. if subroutine.docstring is not None:
  172. docs += f"{subroutine.signature or subroutine.name} - {subroutine.docstring.format(**kwargs)}\n"
  173. return docs
  174. class ParseCommandDetailed(ParseCommandBash):
  175. """
  176. # command_name:
  177. # "docstring"
  178. # signature: "signature"
  179. # arguments:
  180. # arg1 (type) [required]: "description"
  181. # arg2 (type) [optional]: "description"
  182. """
  183. def get_signature(cmd):
  184. signature = cmd.name
  185. if "arguments" in cmd.__dict__ and cmd.arguments is not None:
  186. if cmd.end_name is None:
  187. for param, settings in cmd.arguments.items():
  188. if settings["required"]:
  189. signature += f" <{param}>"
  190. else:
  191. signature += f" [<{param}>]"
  192. else:
  193. for param, settings in list(cmd.arguments.items())[:-1]:
  194. if settings["required"]:
  195. signature += f" <{param}>"
  196. else:
  197. signature += f" [<{param}>]"
  198. signature += f"\n{list(cmd.arguments[-1].keys())[0]}\n{cmd.end_name}"
  199. return signature
  200. def generate_command_docs(
  201. self,
  202. commands: List[Command],
  203. subroutine_types,
  204. **kwargs,
  205. ) -> str:
  206. docs = ""
  207. for cmd in commands + subroutine_types:
  208. docs += f"{cmd.name}:\n"
  209. if cmd.docstring is not None:
  210. docs += f" docstring: {cmd.docstring}\n"
  211. if cmd.signature is not None:
  212. docs += f" signature: {cmd.signature}\n"
  213. else:
  214. docs += f" signature: {self.get_signature(cmd)}\n"
  215. if "arguments" in cmd.__dict__ and cmd.arguments is not None:
  216. docs += " arguments:\n"
  217. for param, settings in cmd.arguments.items():
  218. req_string = "required" if settings["required"] else "optional"
  219. docs += f" - {param} ({settings['type']}) [{req_string}]: {settings['description']}\n"
  220. docs += "\n"
  221. return docs