utils.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # This file contains helpful utilities for the rest of the code, encompassing
  2. # parsing, environment variables, logging, etc.
  3. # TODO: switch to click
  4. import argparse
  5. import os
  6. import shutil
  7. import json
  8. from typing import Dict, Any, Optional
  9. import re
  10. # logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  11. # logger = logging.getLogger(__name__)
  12. def parse_global_args():
  13. parser = argparse.ArgumentParser(description="Parse global parameters")
  14. parser.add_argument('--llm_name', type=str, default="gpt-4o-mini", help="Specify the LLM name of AIOS")
  15. parser.add_argument('--max_gpu_memory', type=json.loads, help="Max gpu memory allocated for the LLM")
  16. parser.add_argument('--eval_device', type=str, help="Evaluation device")
  17. parser.add_argument('--max_new_tokens', type=int, default=256, help="The maximum number of new tokens for generation")
  18. parser.add_argument("--agent_log_mode", type=str,default="console",choices=["console", "file"])
  19. parser.add_argument("--scheduler_log_mode", type=str,default="console",choices=["console", "file"])
  20. parser.add_argument("--llm_kernel_log_mode", type=str, default="console", choices=["console", "file"])
  21. parser.add_argument("--use_backend", type=str, default="ollama", choices=["ollama", "vllm"])
  22. return parser
  23. def extract_before_parenthesis(s: str) -> str:
  24. match = re.search(r'^(.*?)\([^)]*\)', s)
  25. return match.group(1) if match else s
  26. def get_from_dict_or_env(
  27. data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
  28. ) -> str:
  29. """Get a value from a dictionary or an environment variable."""
  30. if key in data and data[key]:
  31. return data[key]
  32. else:
  33. return get_from_env(key, env_key, default=default)
  34. def get_from_env(env_key: str, default: Optional[str] = None) -> str:
  35. """Get a value from an environment variable."""
  36. if env_key in os.environ and os.environ[env_key]:
  37. return os.environ[env_key]
  38. elif default is not None:
  39. return default
  40. else:
  41. raise ValueError(
  42. f"Did not find {env_key}, please add an environment variable"
  43. f" `{env_key}` which contains it. "
  44. )
  45. class Logger:
  46. def __init__(self, log_mode) -> None:
  47. self.log_mode = log_mode
  48. def log(self, info, path=None):
  49. if self.log_mode == "console":
  50. print(info)
  51. else:
  52. assert self.log_mode == "file"
  53. with open(path, "w") as w:
  54. w.write(info + "\n")
  55. def delete_directories(root_dir, target_dirs):
  56. """
  57. Recursively deletes directories with names in target_dirs starting from root_dir.
  58. """
  59. for dirpath, dirnames, filenames in os.walk(root_dir, topdown=False):
  60. for dirname in dirnames:
  61. if dirname in target_dirs:
  62. full_path = os.path.join(dirpath, dirname)
  63. # print(f"Deleting {full_path}...")
  64. shutil.rmtree(full_path, ignore_errors=True)
  65. def humanify_agent(input_string: str):
  66. """ turns 'author/example_agent' into 'Example Agent' """
  67. last_part = input_string.split('/')[-1].replace('_', ' ')
  68. return ' '.join(word.capitalize() for word in last_part.split())