cli.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. import argparse
  2. from simple_colors import *
  3. from exec_engine.pipeline import *
  4. from exec_engine.credentials.credentials_utils import *
  5. from halo import Halo
  6. import os
  7. from pathlib import Path
  8. from authorizations.scripts.authorization_utils import authorize_service
  9. from main import ExecutionEngine, PythonAPIExecutor
  10. from exec_engine.utils import SQL_Type, Filesystem_Type
  11. from exec_engine.db_manager import MySQLManager, SQLiteManager
  12. from dotenv import load_dotenv
  13. import questionary
  14. GORILLA_EMOJI = "🦍 "
  15. SUCCESS = u'\u2713'
  16. USER_CONFIG_PATH = os.path.join(os.path.dirname(Path(os.path.realpath(__file__))), "user_config.json")
  17. default_config = {'max_attempt' : 1,
  18. 'option': 2,
  19. 'show_debug': True,
  20. 'model': "gpt-4-turbo-preview",
  21. 'undo': True,
  22. 'dbtype': 'sqlite',
  23. 'lfs_limit': 200,
  24. 'fs_path': ""
  25. }
  26. def insert_callback(service, key):
  27. print(SUCCESS)
  28. with Halo(text=f"{GORILLA_EMOJI}inserting creds...", spinner="dots"):
  29. insert_creds(service, key, target = CREDS_FOLDER_PATH, cred_type="raw")
  30. def list_callback():
  31. print(list_creds(target = CREDS_FOLDER_PATH))
  32. def restful_callback(prompt, generate_mode):
  33. engine = ExecutionEngine(generate_mode=generate_mode)
  34. # Specify a negation manager (e.g NaiveNegationAPIPairManager) when
  35. # initializing PythonAPIExecutor to record the user feedback for the negated API
  36. engine.api_executor = PythonAPIExecutor(engine.docker_sandbox)
  37. creds, services = engine.api_executor.prepare_credentials(prompt)
  38. if not creds:
  39. for service in list_supported_services():
  40. if service in prompt:
  41. print('Warning: detect keyword {service} but {service} has not been authorized'.format(service=service))
  42. if not services:
  43. api_string_extension = ""
  44. else:
  45. api_string_extension = "from {services} API...".format(services=services)
  46. with Halo(text=f"{GORILLA_EMOJI}fetching response {api_string_extension}".format(api_string_extension), spinner="dots"):
  47. response, forward_call, backward_call = prompt_execute(
  48. engine, prompt, services=services, creds=creds, max_attempt=get_config('max_attempt'), model=get_config('model'))
  49. if response['output']:
  50. print('\n', '\n'.join(response["output"][0]))
  51. else:
  52. print('\n', "execution failed with the following debug messages:")
  53. print('\n', response['debug'])
  54. return
  55. if default_config["undo"]:
  56. answer = questionary.select("Are you sure you want to keep the changes", choices=["Commit", "Undo"]).ask()
  57. if answer.lower() == "undo":
  58. # if there is a match with print_pattern, it's highly likely that it is only a print message
  59. print_pattern = r'^\s*print\s*\((?:.|\n)*\)\s*$'
  60. matches = re.search(print_pattern, backward_call, re.DOTALL)
  61. if matches:
  62. print(engine.api_executor.execute_api_call(backward_call, services)["output"])
  63. else:
  64. print("Warning: the undo feature is still in beta mode and may cause irreversible changes\n" +
  65. "Gorilla will execute the following code:\n{}".format(backward_call))
  66. confirmation = questionary.select("", choices=["Confirm Undo", "Cancel Undo"]).ask()
  67. if confirmation == "Confirm Undo":
  68. print(engine.api_executor.execute_api_call(backward_call, services)["output"])
  69. else:
  70. print("Abort undo, execution completed!")
  71. if engine.api_executor.negation_manager != None:
  72. feedback = questionary.select("How would you rate the suggested negation API?",
  73. choices=["Correct", "Incorrect", "Skip"]).ask()
  74. if feedback == "Correct":
  75. engine.api_executor.negation_manager.insert_log(forward_call, backward_call, True)
  76. elif feedback == "Incorrect":
  77. engine.api_executor.negation_manager.insert_log(forward_call, backward_call, False)
  78. print("Execution Completed!")
  79. def initialize_user_config():
  80. if os.path.exists(USER_CONFIG_PATH):
  81. return
  82. with open(USER_CONFIG_PATH, 'w') as j:
  83. json.dump(default_config, j)
  84. print("Config file created successfully.")
  85. def update_user_config(key, value):
  86. with open(USER_CONFIG_PATH, 'r') as j:
  87. oldconfig = json.load(j)
  88. if key == 'max_attempt' or key == 'option' or key == 'lfs_limit':
  89. value = int(value)
  90. elif key == 'show_debug':
  91. value = value.lower() == 'true'
  92. elif key == 'fs_path':
  93. value = os.path.join(os.getcwd(), value)
  94. if not os.path.exists(value) and not os.path.isdir(value):
  95. print("Please make sure you enter a valid directory path!")
  96. return
  97. modified = False
  98. if oldconfig[key] != value:
  99. modified = True
  100. oldconfig[key] = value
  101. if modified:
  102. with open(USER_CONFIG_PATH, 'w') as j:
  103. json.dump(oldconfig, j)
  104. print("Config file modified successfully.")
  105. def get_config(key):
  106. with open(USER_CONFIG_PATH, 'r') as j:
  107. config = json.load(j)
  108. return config[key]
  109. def authorize_callback(services):
  110. supported_services = list_supported_services()
  111. for service in services:
  112. if service in supported_services:
  113. try:
  114. authorize_service(service)
  115. except Exception as e:
  116. print(e)
  117. print("Failed to authorize user's {service} account".format(service=service))
  118. else:
  119. print("{service} is currently not supported".format(service=service))
  120. def fs_callback(prompt, generate_mode):
  121. path = get_config('fs_path')
  122. if not path:
  123. path = os.getcwd()
  124. path = os.path.abspath(path)
  125. engine = ExecutionEngine(path=path, generate_mode=generate_mode)
  126. option = get_config('option')
  127. engine.initialize_fs(debug_path=path, git_init=option == 2)
  128. if option == 1:
  129. engine.set_dry_run(Filesystem_Type, True)
  130. else:
  131. engine.set_dry_run(Filesystem_Type, False)
  132. api_call, neg_api_call = engine.gen_api_pair(prompt, Filesystem_Type, None, model=get_config('model'))
  133. print(black("Do you want to execute the following filesystem command...", 'bold') + '\n' + magenta(api_call, 'bold'))
  134. answer = questionary.select("",
  135. choices=["Yes", "No"]).ask()
  136. if answer == "No":
  137. print("Execution abandoned.")
  138. return
  139. try:
  140. engine.exec_api_call(api_call=api_call, api_type=Filesystem_Type, debug_neg=neg_api_call)
  141. except RuntimeError as e :
  142. print(f"Execution Failed: {e}")
  143. return
  144. option_to_method = {
  145. 1: 'negation call',
  146. 2: 'git reset'
  147. }
  148. answer = questionary.select("Are you sure you want to keep the changes",
  149. choices=["Commit", "Undo" + " ({})".
  150. format(option_to_method[option])]).ask()
  151. if option == 2:
  152. if answer == "Commit":
  153. commit_msg = questionary.text("Enter a commit message [Optional]: ").ask()
  154. engine.commit_api_call(Filesystem_Type, commit_msg)
  155. print("Execution commited.")
  156. else:
  157. engine.undo_api_call(Filesystem_Type, option=option)
  158. print("Execution undone.")
  159. else:
  160. if answer == "Commit":
  161. print("Execution completed.")
  162. else:
  163. engine.exec_api_call(neg_api_call, api_type=Filesystem_Type)
  164. print("Execution undone.")
  165. def remove_creds_callback(services):
  166. try:
  167. remove_creds(services)
  168. except Exception as e:
  169. print(e)
  170. print("An unexpected error occured while removing credentials")
  171. def db_callback(prompt, generate_mode):
  172. config = {
  173. 'user': os.environ.get('DATABASE_USER'),
  174. 'password': os.environ.get('DATABASE_PASSWORD'),
  175. 'host': os.environ.get('DATABASE_HOST'),
  176. 'database': os.environ.get('DATABASE_NAME'),
  177. 'path': os.environ.get('DATABASE_PATH')
  178. }
  179. engine = ExecutionEngine(generate_mode=generate_mode)
  180. db_type = get_config('dbtype')
  181. db_manager = None
  182. try:
  183. if db_type == 'mysql':
  184. db_manager = MySQLManager(config, docker_sandbox=engine.docker_sandbox)
  185. elif db_type == 'sqlite':
  186. db_manager = SQLiteManager(config, docker_sandbox=engine.docker_sandbox)
  187. except Exception as e:
  188. print(f"Error during {db_type} Manager Init: {e}")
  189. return
  190. db_manager.connect()
  191. option = get_config('option')
  192. if option == 1:
  193. engine.set_dry_run(SQL_Type, True)
  194. else:
  195. engine.set_dry_run(SQL_Type, False)
  196. engine.initialize_db(debug_manager=db_manager)
  197. api_call, neg_api_call = engine.gen_api_pair(prompt, SQL_Type, None, model=get_config('model'))
  198. if neg_api_call == None and option == 1:
  199. print("Error: option 1 requires negation API call. neg_api_call is None.")
  200. return
  201. print(black("Do you want to execute the following database command...", 'bold') + '\n' + magenta(api_call, 'bold'))
  202. answer = questionary.select("",
  203. choices=["Yes", "No"]).ask()
  204. if answer == "No":
  205. print("Execution abandoned.")
  206. return
  207. try:
  208. engine.exec_api_call(api_call=api_call, api_type=SQL_Type, debug_neg=neg_api_call)
  209. if option == 1:
  210. engine.commit_api_call(SQL_Type)
  211. except RuntimeError as e :
  212. print(f"Execution Failed: {e}")
  213. return
  214. option_to_method = {
  215. 1: 'negation call',
  216. 2: 'db rollback'
  217. }
  218. answer = questionary.select("Are you sure you want to keep the changes",
  219. choices=["Commit", "Undo" + " ({})".
  220. format(option_to_method[option])]).ask()
  221. if option == 2:
  222. if answer == "Commit":
  223. engine.commit_api_call(SQL_Type)
  224. print("Execution commited.")
  225. else:
  226. engine.undo_api_call(SQL_Type, option=option)
  227. print("Execution undone.")
  228. else:
  229. if answer == "Commit":
  230. print("Execution completed!")
  231. else:
  232. engine.exec_api_call(neg_api_call, api_type=SQL_Type)
  233. engine.commit_api_call(SQL_Type)
  234. print("Execution undone.")
  235. def exit_with_help_message(parser):
  236. print(green("To execute a prompt with a specified execution type", ['bold']))
  237. # retrieve subparsers from parser
  238. subparsers_actions = [
  239. action for action in parser._actions
  240. if isinstance(action, argparse._SubParsersAction)]
  241. # there will probably only be one subparser_action,
  242. # but better save than sorry
  243. for subparsers_action in subparsers_actions:
  244. # get all subparsers and print help
  245. for choice, subparser in subparsers_action.choices.items():
  246. print(subparser.format_help())
  247. print(green("To perform other Gorilla-x operations", ['bold']))
  248. parser.print_help()
  249. parser.exit()
  250. class _HelpAction(argparse._HelpAction):
  251. def __call__(self, parser, namespace, values, option_string=None):
  252. exit_with_help_message(parser)
  253. class ArgumentParser(argparse.ArgumentParser):
  254. def error(self, message):
  255. exit_with_help_message(self)
  256. def main():
  257. initialize_user_config()
  258. parser = ArgumentParser(add_help=False)
  259. subparser = parser.add_subparsers(dest='command')
  260. execute_parser = subparser.add_parser("execute", add_help=False)
  261. execute_parser.add_argument("-prompt", nargs='*', metavar='prompt', help="prompt for Gorilla-x to execute")
  262. execute_parser.add_argument("-type", nargs=1, metavar='type', help="specify the execution type as either 'rest', 'db', or 'fs'")
  263. execute_parser.add_argument("-generate_mode", metavar='gen_mode', help="specify how to use the LLM, either 'default', 'function_in_context' or 'function_calling_native'", default='default')
  264. parser.add_argument('--help', action=_HelpAction)
  265. parser.add_argument('-insert_creds', action='store', metavar=('service', 'key'), nargs=2, help="Inserts the service-key pair to Gorilla's secret store.")
  266. parser.add_argument('-list_creds', action='store_true', help="Lists all the currently registered credentials.")
  267. parser.add_argument('-authorize', action='store', metavar='service', nargs=1, help="Perform OAuth2 authorization and retrieve access token from the service")
  268. parser.add_argument('-remove_creds', action='extend', metavar='service', nargs="+", help="Removes previously authorized credentials. Enter ALL as parameter to delete all creds")
  269. parser.add_argument('-set_config', action='store', metavar=('name', 'value'), nargs=2, help="Updates the user config with the corresponding key value pairs")
  270. try:
  271. args = parser.parse_args()
  272. except argparse.ArgumentError:
  273. parser.print_help()
  274. return
  275. # load the environment variables
  276. load_dotenv()
  277. if args.command == "execute":
  278. if args.prompt and args.type:
  279. prompt = " ".join(args.prompt)
  280. apitype = args.type[0].lower()
  281. if "rest" in apitype:
  282. restful_callback(prompt, args.generate_mode)
  283. elif "db" in apitype:
  284. db_callback(prompt, args.generate_mode)
  285. elif "fs" in apitype:
  286. fs_callback(prompt, args.generate_mode)
  287. else:
  288. print("Error: invalid execution type. The execution types Gorilla-x currently support are: \n" +
  289. " 1. RESTful (rest)\n" +
  290. " 2. Database (db)\n" +
  291. " 3. Filesystem (fs)")
  292. return
  293. else:
  294. print("execute requires -prompt and -type to be provided")
  295. return
  296. if args.authorize:
  297. authorize_callback(args.authorize)
  298. elif args.remove_creds:
  299. remove_creds_callback(args.remove_creds)
  300. elif args.list_creds:
  301. list_callback()
  302. elif args.insert_creds:
  303. key = args.insert_creds[0]
  304. path = args.insert_creds[1]
  305. insert_callback(key, path)
  306. elif args.set_config:
  307. key = args.set_config[0]
  308. value = args.set_config[1]
  309. if key.lower() == 'max_attempt':
  310. if not value.isnumeric():
  311. print("Please enter a positive integer.")
  312. return
  313. else:
  314. value = int(value)
  315. key = 'max_attempt'
  316. elif key.lower() == 'model':
  317. if value.isdigit():
  318. print("Please enter a valid model version.")
  319. else:
  320. value = value.lower()
  321. key = "model"
  322. update_user_config(key, value)
  323. else:
  324. exit_with_help_message(parser)
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()