create_images.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from __future__ import annotations
  2. import re
  3. import asyncio
  4. from .. import debug
  5. from ..typing import CreateResult, Messages
  6. from .types import BaseProvider, ProviderType
  7. from ..image import ImageResponse
  8. system_message = """
  9. You can generate images, pictures, photos or img with the DALL-E 3 image generator.
  10. To generate an image with a prompt, do this:
  11. <img data-prompt=\"keywords for the image\">
  12. Never use own image links. Don't wrap it in backticks.
  13. It is important to use a only a img tag with a prompt.
  14. <img data-prompt=\"image caption\">
  15. """
  16. class CreateImagesProvider(BaseProvider):
  17. """
  18. Provider class for creating images based on text prompts.
  19. This provider handles image creation requests embedded within message content,
  20. using provided image creation functions.
  21. Attributes:
  22. provider (ProviderType): The underlying provider to handle non-image related tasks.
  23. create_images (callable): A function to create images synchronously.
  24. create_images_async (callable): A function to create images asynchronously.
  25. system_message (str): A message that explains the image creation capability.
  26. include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
  27. __name__ (str): Name of the provider.
  28. url (str): URL of the provider.
  29. working (bool): Indicates if the provider is operational.
  30. supports_stream (bool): Indicates if the provider supports streaming.
  31. """
  32. def __init__(
  33. self,
  34. provider: ProviderType,
  35. create_images: callable,
  36. create_async: callable,
  37. system_message: str = system_message,
  38. include_placeholder: bool = True
  39. ) -> None:
  40. """
  41. Initializes the CreateImagesProvider.
  42. Args:
  43. provider (ProviderType): The underlying provider.
  44. create_images (callable): Function to create images synchronously.
  45. create_async (callable): Function to create images asynchronously.
  46. system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
  47. include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
  48. """
  49. self.provider = provider
  50. self.create_images = create_images
  51. self.create_images_async = create_async
  52. self.system_message = system_message
  53. self.include_placeholder = include_placeholder
  54. self.__name__ = provider.__name__
  55. self.url = provider.url
  56. self.working = provider.working
  57. self.supports_stream = provider.supports_stream
  58. def create_completion(
  59. self,
  60. model: str,
  61. messages: Messages,
  62. stream: bool = False,
  63. **kwargs
  64. ) -> CreateResult:
  65. """
  66. Creates a completion result, processing any image creation prompts found within the messages.
  67. Args:
  68. model (str): The model to use for creation.
  69. messages (Messages): The messages to process, which may contain image prompts.
  70. stream (bool, optional): Indicates whether to stream the results. Defaults to False.
  71. **kwargs: Additional keywordarguments for the provider.
  72. Yields:
  73. CreateResult: Yields chunks of the processed messages, including image data if applicable.
  74. Note:
  75. This method processes messages to detect image creation prompts. When such a prompt is found,
  76. it calls the synchronous image creation function and includes the resulting image in the output.
  77. """
  78. messages.insert(0, {"role": "system", "content": self.system_message})
  79. buffer = ""
  80. for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
  81. if isinstance(chunk, ImageResponse):
  82. yield chunk
  83. elif isinstance(chunk, str) and buffer or "<" in chunk:
  84. buffer += chunk
  85. if ">" in buffer:
  86. match = re.search(r'<img data-prompt="(.*?)">', buffer)
  87. if match:
  88. placeholder, prompt = match.group(0), match.group(1)
  89. start, append = buffer.split(placeholder, 1)
  90. if start:
  91. yield start
  92. if self.include_placeholder:
  93. yield placeholder
  94. if debug.logging:
  95. print(f"Create images with prompt: {prompt}")
  96. yield from self.create_images(prompt)
  97. if append:
  98. yield append
  99. else:
  100. yield buffer
  101. buffer = ""
  102. else:
  103. yield chunk
  104. async def create_async(
  105. self,
  106. model: str,
  107. messages: Messages,
  108. **kwargs
  109. ) -> str:
  110. """
  111. Asynchronously creates a response, processing any image creation prompts found within the messages.
  112. Args:
  113. model (str): The model to use for creation.
  114. messages (Messages): The messages to process, which may contain image prompts.
  115. **kwargs: Additional keyword arguments for the provider.
  116. Returns:
  117. str: The processed response string, including asynchronously generated image data if applicable.
  118. Note:
  119. This method processes messages to detect image creation prompts. When such a prompt is found,
  120. it calls the asynchronous image creation function and includes the resulting image in the output.
  121. """
  122. messages.insert(0, {"role": "system", "content": self.system_message})
  123. response = await self.provider.create_async(model, messages, **kwargs)
  124. matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
  125. results = []
  126. placeholders = []
  127. for placeholder, prompt in matches:
  128. if placeholder not in placeholders:
  129. if debug.logging:
  130. print(f"Create images with prompt: {prompt}")
  131. results.append(self.create_images_async(prompt))
  132. placeholders.append(placeholder)
  133. results = await asyncio.gather(*results)
  134. for idx, result in enumerate(results):
  135. placeholder = placeholder[idx]
  136. if self.include_placeholder:
  137. result = placeholder + result
  138. response = response.replace(placeholder, result)
  139. return response