client_utils.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from abc import ABC
  2. from typing import Any
  3. from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
  4. from openai import AzureOpenAI, OpenAI
  5. import logging
  6. from env_config import read_env_config, set_env
  7. from os import environ, getenv
  8. import time
  9. from threading import Lock
  10. from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
  11. from azure.identity import get_bearer_token_provider
  12. logger = logging.getLogger("client_utils")
  13. def build_openai_client(env_prefix : str = "COMPLETION", **kwargs: Any) -> OpenAI:
  14. """
  15. Build OpenAI client based on the environment variables.
  16. """
  17. kwargs = _remove_empty_values(kwargs)
  18. env = read_env_config(env_prefix)
  19. with set_env(**env):
  20. if is_azure():
  21. auth_args = _get_azure_auth_client_args()
  22. client = AzureOpenAI(**auth_args, **kwargs)
  23. else:
  24. client = OpenAI(**kwargs)
  25. return client
  26. def build_langchain_embeddings(**kwargs: Any) -> OpenAIEmbeddings:
  27. """
  28. Build OpenAI embeddings client based on the environment variables.
  29. """
  30. kwargs = _remove_empty_values(kwargs)
  31. env = read_env_config("EMBEDDING")
  32. with set_env(**env):
  33. if is_azure():
  34. auth_args = _get_azure_auth_client_args()
  35. client = AzureOpenAIEmbeddings(**auth_args, **kwargs)
  36. else:
  37. client = OpenAIEmbeddings(**kwargs)
  38. return client
  39. def _remove_empty_values(d: dict) -> dict:
  40. return {k: v for k, v in d.items() if v is not None}
  41. def _get_azure_auth_client_args() -> dict:
  42. """Handle Azure OpenAI Keyless, Managed Identity and Key based authentication
  43. https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521
  44. """
  45. client_args = {}
  46. if getenv("AZURE_OPENAI_KEY"):
  47. logger.info("Using Azure OpenAI Key based authentication")
  48. client_args["api_key"] = getenv("AZURE_OPENAI_KEY")
  49. else:
  50. if client_id := getenv("AZURE_OPENAI_CLIENT_ID"):
  51. # Authenticate using a user-assigned managed identity on Azure
  52. logger.info("Using Azure OpenAI Managed Identity Keyless authentication")
  53. azure_credential = ManagedIdentityCredential(client_id=client_id)
  54. else:
  55. # Authenticate using the default Azure credential chain
  56. logger.info("Using Azure OpenAI Default Azure Credential Keyless authentication")
  57. azure_credential = DefaultAzureCredential()
  58. client_args["azure_ad_token_provider"] = get_bearer_token_provider(
  59. azure_credential, "https://cognitiveservices.azure.com/.default")
  60. client_args["api_version"] = getenv("AZURE_OPENAI_API_VERSION") or "2024-02-15-preview"
  61. client_args["azure_endpoint"] = getenv("AZURE_OPENAI_ENDPOINT")
  62. client_args["azure_deployment"] = getenv("AZURE_OPENAI_DEPLOYMENT")
  63. return client_args
  64. def is_azure():
  65. azure = "AZURE_OPENAI_ENDPOINT" in environ or "AZURE_OPENAI_KEY" in environ or "AZURE_OPENAI_AD_TOKEN" in environ
  66. if azure:
  67. logger.debug("Using Azure OpenAI environment variables")
  68. else:
  69. logger.debug("Using OpenAI environment variables")
  70. return azure
  71. def safe_min(a: Any, b: Any) -> Any:
  72. if a is None:
  73. return b
  74. if b is None:
  75. return a
  76. return min(a, b)
  77. def safe_max(a: Any, b: Any) -> Any:
  78. if a is None:
  79. return b
  80. if b is None:
  81. return a
  82. return max(a, b)
  83. class UsageStats:
  84. def __init__(self) -> None:
  85. self.start = time.time()
  86. self.completion_tokens = 0
  87. self.prompt_tokens = 0
  88. self.total_tokens = 0
  89. self.end = None
  90. self.duration = 0
  91. self.calls = 0
  92. def __add__(self, other: 'UsageStats') -> 'UsageStats':
  93. stats = UsageStats()
  94. stats.start = safe_min(self.start, other.start)
  95. stats.end = safe_max(self.end, other.end)
  96. stats.completion_tokens = self.completion_tokens + other.completion_tokens
  97. stats.prompt_tokens = self.prompt_tokens + other.prompt_tokens
  98. stats.total_tokens = self.total_tokens + other.total_tokens
  99. stats.duration = self.duration + other.duration
  100. stats.calls = self.calls + other.calls
  101. return stats
  102. class StatsCompleter(ABC):
  103. def __init__(self, create_func):
  104. self.create_func = create_func
  105. self.stats = None
  106. self.lock = Lock()
  107. def __call__(self, *args: Any, **kwds: Any) -> Any:
  108. response = self.create_func(*args, **kwds)
  109. self.lock.acquire()
  110. try:
  111. if not self.stats:
  112. self.stats = UsageStats()
  113. self.stats.completion_tokens += response.usage.completion_tokens
  114. self.stats.prompt_tokens += response.usage.prompt_tokens
  115. self.stats.total_tokens += response.usage.total_tokens
  116. self.stats.calls += 1
  117. return response
  118. finally:
  119. self.lock.release()
  120. def get_stats_and_reset(self) -> UsageStats:
  121. self.lock.acquire()
  122. try:
  123. end = time.time()
  124. stats = self.stats
  125. if stats:
  126. stats.end = end
  127. stats.duration = end - self.stats.start
  128. self.stats = None
  129. return stats
  130. finally:
  131. self.lock.release()
  132. class ChatCompleter(StatsCompleter):
  133. def __init__(self, client):
  134. super().__init__(client.chat.completions.create)
  135. class CompletionsCompleter(StatsCompleter):
  136. def __init__(self, client):
  137. super().__init__(client.completions.create)