|
|
|
@ -3,6 +3,9 @@ from __future__ import annotations
|
|
|
|
|
import time
|
|
|
|
|
import random
|
|
|
|
|
import string
|
|
|
|
|
import asyncio
|
|
|
|
|
import base64
|
|
|
|
|
from aiohttp import ClientSession, BaseConnector
|
|
|
|
|
|
|
|
|
|
from .types import Client as BaseClient
|
|
|
|
|
from .types import ProviderType, FinishReason
|
|
|
|
@ -11,9 +14,11 @@ from .types import AsyncIterResponse, ImageProvider
|
|
|
|
|
from .image_models import ImageModels
|
|
|
|
|
from .helper import filter_json, find_stop, filter_none, cast_iter_async
|
|
|
|
|
from .service import get_last_provider, get_model_and_provider
|
|
|
|
|
from ..Provider import ProviderUtils
|
|
|
|
|
from ..typing import Union, Messages, AsyncIterator, ImageType
|
|
|
|
|
from ..errors import NoImageResponseError
|
|
|
|
|
from ..image import ImageResponse as ImageProviderResponse
|
|
|
|
|
from ..errors import NoImageResponseError, ProviderNotFoundError
|
|
|
|
|
from ..requests.aiohttp import get_connector
|
|
|
|
|
from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
anext
|
|
|
|
@ -156,12 +161,28 @@ class Chat():
|
|
|
|
|
def __init__(self, client: AsyncClient, provider: ProviderType = None):
|
|
|
|
|
self.completions = Completions(client, provider)
|
|
|
|
|
|
|
|
|
|
async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
|
|
|
|
|
async def iter_image_response(
|
|
|
|
|
response: AsyncIterator,
|
|
|
|
|
response_format: str = None,
|
|
|
|
|
connector: BaseConnector = None,
|
|
|
|
|
proxy: str = None
|
|
|
|
|
) -> Union[ImagesResponse, None]:
|
|
|
|
|
async for chunk in response:
|
|
|
|
|
if isinstance(chunk, ImageProviderResponse):
|
|
|
|
|
return ImagesResponse([Image(image) for image in chunk.get_list()])
|
|
|
|
|
if response_format == "b64_json":
|
|
|
|
|
async with ClientSession(
|
|
|
|
|
connector=get_connector(connector, proxy)
|
|
|
|
|
) as session:
|
|
|
|
|
async def fetch_image(image):
|
|
|
|
|
async with session.get(image) as response:
|
|
|
|
|
return base64.b64encode(await response.content.read()).decode()
|
|
|
|
|
images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()])
|
|
|
|
|
return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time()))
|
|
|
|
|
return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time()))
|
|
|
|
|
elif isinstance(chunk, ImageDataResponse):
|
|
|
|
|
return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time()))
|
|
|
|
|
|
|
|
|
|
def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
|
|
|
|
|
def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
|
|
|
|
|
prompt = f"create a image with: {prompt}"
|
|
|
|
|
if provider.__name__ == "You":
|
|
|
|
|
kwargs["chat_mode"] = "create"
|
|
|
|
@ -169,7 +190,6 @@ def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model
|
|
|
|
|
model,
|
|
|
|
|
[{"role": "user", "content": prompt}],
|
|
|
|
|
stream=True,
|
|
|
|
|
proxy=client.get_proxy(),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -179,31 +199,71 @@ class Images():
|
|
|
|
|
self.provider: ImageProvider = provider
|
|
|
|
|
self.models: ImageModels = ImageModels(client)
|
|
|
|
|
|
|
|
|
|
async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
|
|
|
|
|
provider = self.models.get(model, self.provider)
|
|
|
|
|
def get_provider(self, model: str, provider: ProviderType = None):
|
|
|
|
|
if isinstance(provider, str):
|
|
|
|
|
if provider in ProviderUtils.convert:
|
|
|
|
|
provider = ProviderUtils.convert[provider]
|
|
|
|
|
else:
|
|
|
|
|
raise ProviderNotFoundError(f'Provider not found: {provider}')
|
|
|
|
|
else:
|
|
|
|
|
provider = self.models.get(model, self.provider)
|
|
|
|
|
return provider
|
|
|
|
|
|
|
|
|
|
async def generate(
|
|
|
|
|
self,
|
|
|
|
|
prompt,
|
|
|
|
|
model: str = "",
|
|
|
|
|
provider: ProviderType = None,
|
|
|
|
|
response_format: str = None,
|
|
|
|
|
connector: BaseConnector = None,
|
|
|
|
|
proxy: str = None,
|
|
|
|
|
**kwargs
|
|
|
|
|
) -> ImagesResponse:
|
|
|
|
|
provider = self.get_provider(model, provider)
|
|
|
|
|
if hasattr(provider, "create_async_generator"):
|
|
|
|
|
response = create_image(self.client, provider, prompt, **kwargs)
|
|
|
|
|
response = create_image(
|
|
|
|
|
provider,
|
|
|
|
|
prompt,
|
|
|
|
|
**filter_none(
|
|
|
|
|
response_format=response_format,
|
|
|
|
|
connector=connector,
|
|
|
|
|
proxy=self.client.get_proxy() if proxy is None else proxy,
|
|
|
|
|
),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
response = await provider.create_async(prompt)
|
|
|
|
|
return ImagesResponse([Image(image) for image in response.get_list()])
|
|
|
|
|
image = await iter_image_response(response)
|
|
|
|
|
image = await iter_image_response(response, response_format, connector, proxy)
|
|
|
|
|
if image is None:
|
|
|
|
|
raise NoImageResponseError()
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
|
async def create_variation(self, image: ImageType, model: str = None, **kwargs):
|
|
|
|
|
provider = self.models.get(model, self.provider)
|
|
|
|
|
async def create_variation(
|
|
|
|
|
self,
|
|
|
|
|
image: ImageType,
|
|
|
|
|
model: str = None,
|
|
|
|
|
response_format: str = None,
|
|
|
|
|
connector: BaseConnector = None,
|
|
|
|
|
proxy: str = None,
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
provider = self.get_provider(model, provider)
|
|
|
|
|
result = None
|
|
|
|
|
if hasattr(provider, "create_async_generator"):
|
|
|
|
|
response = provider.create_async_generator(
|
|
|
|
|
"",
|
|
|
|
|
[{"role": "user", "content": "create a image like this"}],
|
|
|
|
|
True,
|
|
|
|
|
stream=True,
|
|
|
|
|
image=image,
|
|
|
|
|
proxy=self.client.get_proxy(),
|
|
|
|
|
**filter_none(
|
|
|
|
|
response_format=response_format,
|
|
|
|
|
connector=connector,
|
|
|
|
|
proxy=self.client.get_proxy() if proxy is None else proxy,
|
|
|
|
|
),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
result = iter_image_response(response)
|
|
|
|
|
result = iter_image_response(response, response_format, connector, proxy)
|
|
|
|
|
if result is None:
|
|
|
|
|
raise NoImageResponseError()
|
|
|
|
|
return result
|
|
|
|
|