feat(Blackbox): add image generation support and enhance response handling

kqlio67 2024-09-05 09:58:28 +03:00
parent 55e55d77a2
commit f2f04a00b1


@@ -3,11 +3,12 @@ from __future__ import annotations
 import uuid
 import secrets
 import re
-from aiohttp import ClientSession, ClientResponse
+import base64
+from aiohttp import ClientSession
 from typing import AsyncGenerator, Optional
 
 from ..typing import AsyncResult, Messages, ImageType
-from ..image import to_data_uri
+from ..image import to_data_uri, ImageResponse
 from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 
 class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
@@ -20,12 +21,25 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         "llama-3.1-8b",
         'llama-3.1-70b',
         'llama-3.1-405b',
+        'ImageGeneration',
     ]
 
     model_aliases = {
         "gemini-flash": "gemini-1.5-flash",
     }
 
+    agent_mode_map = {
+        'ImageGeneration': {"mode": True, "id": "ImageGenerationLV45LJp", "name": "Image Generation"},
+    }
+
+    model_id_map = {
+        "blackbox": {},
+        "gemini-1.5-flash": {'mode': True, 'id': 'Gemini'},
+        "llama-3.1-8b": {'mode': True, 'id': "llama-3.1-8b"},
+        'llama-3.1-70b': {'mode': True, 'id': "llama-3.1-70b"},
+        'llama-3.1-405b': {'mode': True, 'id': "llama-3.1-405b"}
+    }
+
     @classmethod
     def get_model(cls, model: str) -> str:
         if model in cls.models:
@@ -35,6 +49,15 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         else:
             return cls.default_model
 
+    @classmethod
+    async def download_image_to_base64_url(cls, url: str) -> str:
+        async with ClientSession() as session:
+            async with session.get(url) as response:
+                image_data = await response.read()
+                base64_data = base64.b64encode(image_data).decode('utf-8')
+                mime_type = response.headers.get('Content-Type', 'image/jpeg')
+                return f"data:{mime_type};base64,{base64_data}"
+
     @classmethod
     async def create_async_generator(
         cls,
@@ -44,7 +67,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
         image: Optional[ImageType] = None,
         image_name: Optional[str] = None,
         **kwargs
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[AsyncResult, None]:
         if image is not None:
             messages[-1]["data"] = {
                 "fileText": image_name,
@@ -72,20 +95,12 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
 
         model = cls.get_model(model)  # Resolve the model alias
 
-        model_id_map = {
-            "blackbox": {},
-            "gemini-1.5-flash": {'mode': True, 'id': 'Gemini'},
-            "llama-3.1-8b": {'mode': True, 'id': "llama-3.1-8b"},
-            'llama-3.1-70b': {'mode': True, 'id': "llama-3.1-70b"},
-            'llama-3.1-405b': {'mode': True, 'id': "llama-3.1-405b"}
-        }
-
         data = {
             "messages": messages,
             "id": random_id,
             "userId": random_user_id,
             "codeModelMode": True,
-            "agentMode": {},
+            "agentMode": cls.agent_mode_map.get(model, {}),
             "trendingAgentMode": {},
             "isMicMode": False,
             "isChromeExt": False,
@@ -93,7 +108,7 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
             "webSearchMode": False,
             "userSystemPrompt": "",
             "githubToken": None,
-            "trendingAgentModel": model_id_map.get(model, {}),  # Default to empty dict if model not found
+            "trendingAgentModel": cls.model_id_map.get(model, {}),
             "maxTokens": None
         }
 
@@ -101,9 +116,41 @@ class Blackbox(AsyncGeneratorProvider, ProviderModelMixin):
                 f"{cls.url}/api/chat", json=data, proxy=proxy
             ) as response:
                 response.raise_for_status()
+                full_response = ""
+                buffer = ""
+                image_base64_url = None
                 async for chunk in response.content.iter_any():
                     if chunk:
+                        # Decode the chunk and clean up unwanted prefixes using a regex
                         decoded_chunk = chunk.decode()
                         cleaned_chunk = re.sub(r'\$@\$.+?\$@\$|\$@\$', '', decoded_chunk)
-                        yield cleaned_chunk
+                        buffer += cleaned_chunk
+
+                        # Check if there's a complete image line in the buffer
+                        image_match = re.search(r'!\[Generated Image\]\((https?://[^\s\)]+)\)', buffer)
+                        if image_match:
+                            image_url = image_match.group(1)
+                            # Download the image and convert to base64 URL
+                            image_base64_url = await cls.download_image_to_base64_url(image_url)
+                            # Remove the image line from the buffer
+                            buffer = re.sub(r'!\[Generated Image\]\(https?://[^\s\)]+\)', '', buffer)
+
+                        # Send text line by line
+                        lines = buffer.split('\n')
+                        for line in lines[:-1]:
+                            if line.strip():
+                                full_response += line + '\n'
+                                yield line + '\n'
+                        buffer = lines[-1]  # Keep the last incomplete line in the buffer
+
+                # Send the remaining buffer if it's not empty
+                if buffer.strip():
+                    full_response += buffer
+                    yield buffer
+
+                # If an image was found, send it as ImageResponse
+                if image_base64_url:
+                    alt_text = "Generated Image"
+                    image_response = ImageResponse(image_base64_url, alt=alt_text)
+                    yield image_response
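
A minimal, hypothetical usage sketch of the new image-generation path follows. It assumes the provider is importable as g4f.Provider.Blackbox and that ImageResponse is exposed from g4f.image (consistent with the relative imports in the diff), and that create_async_generator takes the model name and message list as its first two arguments, as in the signature above; the prompt text and output handling are illustrative only, not part of this commit.

import asyncio

# Assumed import paths; adjust to your checkout if the package layout differs.
from g4f.Provider import Blackbox
from g4f.image import ImageResponse

async def main() -> None:
    messages = [{"role": "user", "content": "A watercolor fox in a forest"}]

    # With the 'ImageGeneration' model, text chunks stream first; if the backend
    # returns a markdown image link, the generator downloads it and yields a
    # final ImageResponse holding a base64 data URI.
    async for chunk in Blackbox.create_async_generator("ImageGeneration", messages):
        if isinstance(chunk, ImageResponse):
            print("\n[received generated image as a base64 data URI]")
        else:
            print(chunk, end="")

asyncio.run(main())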