Add streaming and conversation support to gemini

pull/1995/head
Heiner Lohaus 4 weeks ago
parent 7eb41cfdcb
commit b7624b75a3

@ -1,3 +1,4 @@
from ..providers.base_provider import *
from ..providers.types import FinishReason, Streaming
from ..providers.conversation import BaseConversation
from .helper import get_cookies, format_prompt

@ -18,11 +18,11 @@ except ImportError:
from ... import debug
from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator
from ..base_provider import AsyncGeneratorProvider
from ..base_provider import AsyncGeneratorProvider, BaseConversation
from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError, MissingRequirementsError
from ...image import to_bytes, ImageResponse, ImageDataResponse
from ...image import ImageResponse, to_bytes
from ...webdriver import get_browser, get_driver_cookies
REQUEST_HEADERS = {
@ -32,7 +32,7 @@ REQUEST_HEADERS = {
'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
'x-same-domain': '1',
}
REQUEST_BL_PARAM = "boq_assistant-bard-web-server_20240421.18_p0"
REQUEST_BL_PARAM = "boq_assistant-bard-web-server_20240519.16_p0"
REQUEST_URL = "https://gemini.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate"
UPLOAD_IMAGE_URL = "https://content-push.googleapis.com/upload/"
UPLOAD_IMAGE_HEADERS = {
@ -57,6 +57,8 @@ class Gemini(AsyncGeneratorProvider):
image_models = ["gemini"]
default_vision_model = "gemini"
_cookies: Cookies = None
_snlm0e: str = None
_sid: str = None
@classmethod
async def nodriver_login(cls, proxy: str = None) -> AsyncIterator[str]:
@ -117,42 +119,40 @@ class Gemini(AsyncGeneratorProvider):
model: str,
messages: Messages,
proxy: str = None,
api_key: str = None,
cookies: Cookies = None,
connector: BaseConnector = None,
image: ImageType = None,
image_name: str = None,
response_format: str = None,
return_conversation: bool = False,
conversation: Conversation = None,
language: str = "en",
**kwargs
) -> AsyncResult:
prompt = format_prompt(messages)
if api_key is not None:
if cookies is None:
cookies = {}
cookies["__Secure-1PSID"] = api_key
prompt = format_prompt(messages) if conversation is None else messages[-1]["content"]
cls._cookies = cookies or cls._cookies or get_cookies(".google.com", False, True)
base_connector = get_connector(connector, proxy)
async with ClientSession(
headers=REQUEST_HEADERS,
connector=base_connector
) as session:
snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
if not snlm0e:
if not cls._snlm0e:
await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
if not cls._snlm0e:
async for chunk in cls.nodriver_login(proxy):
yield chunk
if cls._cookies is None:
async for chunk in cls.webdriver_login(proxy):
yield chunk
if not snlm0e:
if not cls._snlm0e:
if cls._cookies is None or "__Secure-1PSID" not in cls._cookies:
raise MissingAuthError('Missing "__Secure-1PSID" cookie')
snlm0e = await cls.fetch_snlm0e(session, cls._cookies)
if not snlm0e:
await cls.fetch_snlm0e(session, cls._cookies)
if not cls._snlm0e:
raise RuntimeError("Invalid cookies. SNlM0e not found")
image_url = await cls.upload_image(base_connector, to_bytes(image), image_name) if image else None
async with ClientSession(
cookies=cls._cookies,
headers=REQUEST_HEADERS,
@ -160,13 +160,17 @@ class Gemini(AsyncGeneratorProvider):
) as client:
params = {
'bl': REQUEST_BL_PARAM,
'hl': language,
'_reqid': random.randint(1111, 9999),
'rt': 'c'
'rt': 'c',
"f.sid": cls._sid,
}
data = {
'at': snlm0e,
'at': cls._snlm0e,
'f.req': json.dumps([None, json.dumps(cls.build_request(
prompt,
language=language,
conversation=conversation,
image_url=image_url,
image_name=image_name
))])
@ -177,19 +181,33 @@ class Gemini(AsyncGeneratorProvider):
params=params,
) as response:
await raise_for_status(response)
response = await response.text()
response_part = json.loads(json.loads(response.splitlines()[-5])[0][2])
if response_part[4] is None:
response_part = json.loads(json.loads(response.splitlines()[-7])[0][2])
content = response_part[4][0][1][0]
image_prompt = None
match = re.search(r'\[Imagen of (.*?)\]', content)
if match:
image_prompt = match.group(1)
content = content.replace(match.group(0), '')
yield content
image_prompt = response_part = None
last_content_len = 0
async for line in response.content:
try:
try:
line = json.loads(line)
except ValueError:
continue
if not isinstance(line, list):
continue
if len(line[0]) < 3 or not line[0][2]:
continue
response_part = json.loads(line[0][2])
if not response_part[4]:
continue
if return_conversation:
yield Conversation(response_part[1][0], response_part[1][1], response_part[4][0][0])
content = response_part[4][0][1][0]
except (ValueError, KeyError, TypeError, IndexError) as e:
print(f"{cls.__name__}:{e.__class__.__name__}:{e}")
continue
match = re.search(r'\[Imagen of (.*?)\]', content)
if match:
image_prompt = match.group(1)
content = content.replace(match.group(0), '')
yield content[last_content_len:]
last_content_len = len(content)
if image_prompt:
images = [image[0][3][3] for image in response_part[4][0][12][7][0]]
if response_format == "b64_json":
@ -208,9 +226,8 @@ class Gemini(AsyncGeneratorProvider):
def build_request(
prompt: str,
conversation_id: str = "",
response_id: str = "",
choice_id: str = "",
language: str,
conversation: Conversation = None,
image_url: str = None,
image_name: str = None,
tools: list[list[str]] = []
@ -218,8 +235,15 @@ class Gemini(AsyncGeneratorProvider):
image_list = [[[image_url, 1], image_name]] if image_url else []
return [
[prompt, 0, None, image_list, None, None, 0],
["en"],
[conversation_id, response_id, choice_id, None, None, []],
[language],
[
None if conversation is None else conversation.conversation_id,
None if conversation is None else conversation.response_id,
None if conversation is None else conversation.choice_id,
None,
None,
[]
],
None,
None,
None,
@ -265,7 +289,20 @@ class Gemini(AsyncGeneratorProvider):
async def fetch_snlm0e(cls, session: ClientSession, cookies: Cookies):
async with session.get(cls.url, cookies=cookies) as response:
await raise_for_status(response)
text = await response.text()
match = re.search(r'SNlM0e\":\"(.*?)\"', text)
response_text = await response.text()
match = re.search(r'SNlM0e\":\"(.*?)\"', response_text)
if match:
return match.group(1)
cls._snlm0e = match.group(1)
sid_match = re.search(r'"FdrFJe":"([\d-]+)"', response_text)
if sid_match:
cls._sid = sid_match.group(1)
class Conversation(BaseConversation):
def __init__(self,
conversation_id: str = "",
response_id: str = "",
choice_id: str = ""
) -> None:
self.conversation_id = conversation_id
self.response_id = response_id
self.choice_id = choice_id

@ -18,6 +18,7 @@ from ..Provider import ProviderUtils
from ..typing import Union, Messages, AsyncIterator, ImageType
from ..errors import NoImageResponseError, ProviderNotFoundError
from ..requests.aiohttp import get_connector
from ..providers.conversation import BaseConversation
from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse
try:
@ -42,6 +43,9 @@ async def iter_response(
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
elif isinstance(chunk, BaseConversation):
yield chunk
continue
content += str(chunk)
count += 1
if max_tokens is not None and count >= max_tokens:

@ -6,6 +6,7 @@ import string
from ..typing import Union, Iterator, Messages, ImageType
from ..providers.types import BaseProvider, ProviderType, FinishReason
from ..providers.conversation import BaseConversation
from ..image import ImageResponse as ImageProviderResponse
from ..errors import NoImageResponseError
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
@ -29,6 +30,9 @@ def iter_response(
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
elif isinstance(chunk, BaseConversation):
yield chunk
continue
content += str(chunk)
if max_tokens is not None and idx + 1 >= max_tokens:
finish_reason = "length"

Loading…
Cancel
Save