Merge pull request #1633 from hlohaus/flow

Fix GeminiPro auth for normal user
Add rdns support for proxies
Improve filter_messages in gui
This commit is contained in:
H Lohaus 2024-02-26 11:30:17 +01:00 committed by GitHub
commit 36e7665613
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 36 additions and 19 deletions

View File

@ -70,7 +70,7 @@ class Bing(AsyncGeneratorProvider):
gpt4_turbo = True if model.startswith("gpt-4-turbo") else False gpt4_turbo = True if model.startswith("gpt-4-turbo") else False
return stream_generate(prompt, tone, image, context, cookies, get_connector(connector, proxy), web_search, gpt4_turbo, timeout) return stream_generate(prompt, tone, image, context, cookies, get_connector(connector, proxy, True), web_search, gpt4_turbo, timeout)
def create_context(messages: Messages) -> str: def create_context(messages: Messages) -> str:
""" """

View File

@ -2,12 +2,13 @@ from __future__ import annotations
import base64 import base64
import json import json
from aiohttp import ClientSession from aiohttp import ClientSession, BaseConnector
from ..typing import AsyncResult, Messages, ImageType from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import to_bytes, is_accepted_format from ..image import to_bytes, is_accepted_format
from ..errors import MissingAuthError from ..errors import MissingAuthError
from .helper import get_connector
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://ai.google.dev" url = "https://ai.google.dev"
@ -27,6 +28,7 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
api_key: str = None, api_key: str = None,
api_base: str = None, api_base: str = None,
image: ImageType = None, image: ImageType = None,
connector: BaseConnector = None,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
model = "gemini-pro-vision" if not model and image else model model = "gemini-pro-vision" if not model and image else model
@ -34,18 +36,19 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
if not api_key: if not api_key:
raise MissingAuthError('Missing "api_key"') raise MissingAuthError('Missing "api_key"')
headers = params = None
if api_base:
headers = {"Authorization": f"Bearer {api_key}"}
else:
params = {"key": api_key}
if not api_base: if not api_base:
api_base = f"https://generativelanguage.googleapis.com/v1beta" api_base = f"https://generativelanguage.googleapis.com/v1beta"
method = "streamGenerateContent" if stream else "generateContent" method = "streamGenerateContent" if stream else "generateContent"
url = f"{api_base.rstrip('/')}/models/{model}:{method}" url = f"{api_base.rstrip('/')}/models/{model}:{method}"
headers = None async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
if api_base:
headers = {f"Authorization": "Bearer {api_key}"}
else:
url += f"?key={api_key}"
async with ClientSession(headers=headers) as session:
contents = [ contents = [
{ {
"role": "model" if message["role"] == "assistant" else message["role"], "role": "model" if message["role"] == "assistant" else message["role"],
@ -71,10 +74,11 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
"topK": kwargs.get("top_k"), "topK": kwargs.get("top_k"),
} }
} }
async with session.post(url, json=data, proxy=proxy) as response: async with session.post(url, params=params, json=data) as response:
if not response.ok: if not response.ok:
data = await response.json() data = await response.json()
raise RuntimeError(data[0]["error"]["message"]) data = data[0] if isinstance(data, list) else data
raise RuntimeError(data["error"]["message"])
if stream: if stream:
lines = [] lines = []
async for chunk in response.content: async for chunk in response.content:

View File

@ -105,7 +105,7 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin):
async with ClientSession( async with ClientSession(
headers=headers, headers=headers,
cookie_jar=cls._cookie_jar, cookie_jar=cls._cookie_jar,
connector=get_connector(connector, proxy) connector=get_connector(connector, proxy, True)
) as session: ) as session:
cls._auth_code = auth if isinstance(auth, str) else cls._auth_code cls._auth_code = auth if isinstance(auth, str) else cls._auth_code
if not cls._auth_code: if not cls._auth_code:

View File

@ -173,7 +173,6 @@
<option value="">Provider: Auto</option> <option value="">Provider: Auto</option>
<option value="Bing">Bing</option> <option value="Bing">Bing</option>
<option value="OpenaiChat">OpenaiChat</option> <option value="OpenaiChat">OpenaiChat</option>
<option value="HuggingChat">HuggingChat</option>
<option value="Gemini">Gemini</option> <option value="Gemini">Gemini</option>
<option value="Liaobots">Liaobots</option> <option value="Liaobots">Liaobots</option>
<option value="Phind">Phind</option> <option value="Phind">Phind</option>

View File

@ -121,6 +121,20 @@ const remove_cancel_button = async () => {
}; };
const filter_messages = (messages) => { const filter_messages = (messages) => {
// Removes none user messages at end
let last_message;
while (last_message = new_messages.pop()) {
if (last_message["role"] == "user") {
new_messages.push(last_message);
break;
}
}
// Remove history, if it is selected
if (document.getElementById('history')?.checked) {
messages = [messages[messages.length-1]];
}
let new_messages = []; let new_messages = [];
for (i in messages) { for (i in messages) {
new_message = messages[i]; new_message = messages[i];
@ -135,6 +149,7 @@ const filter_messages = (messages) => {
new_messages.push(new_message) new_messages.push(new_message)
} }
} }
return new_messages; return new_messages;
} }
@ -143,10 +158,6 @@ const ask_gpt = async () => {
messages = await get_messages(window.conversation_id); messages = await get_messages(window.conversation_id);
total_messages = messages.length; total_messages = messages.length;
// Remove history, if it is selected
if (document.getElementById('history')?.checked) {
messages = [messages[messages.length-1]];
}
messages = filter_messages(messages); messages = filter_messages(messages);
window.scrollTo(0, 0); window.scrollTo(0, 0);

View File

@ -51,11 +51,14 @@ def get_random_hex() -> str:
""" """
return secrets.token_hex(16).zfill(32) return secrets.token_hex(16).zfill(32)
def get_connector(connector: BaseConnector = None, proxy: str = None, rdns: bool = False) -> Optional[BaseConnector]:
    """Return an aiohttp connector, building a SOCKS/HTTP proxy connector on demand.

    A caller-supplied connector is always returned unchanged; a proxy
    connector is only created when `proxy` is set and no connector was given.
    Raises MissingRequirementsError when "aiohttp_socks" is not installed.
    """
    if connector or not proxy:
        # Nothing to build: either the caller already provided a connector,
        # or no proxy was requested at all.
        return connector
    try:
        from aiohttp_socks import ProxyConnector
    except ImportError:
        raise MissingRequirementsError('Install "aiohttp_socks" package for proxy support')
    # The "socks5h" scheme means "resolve hostnames on the proxy side";
    # aiohttp_socks spells that as the rdns=True flag on a plain socks5 URL.
    if proxy.startswith("socks5h://"):
        proxy = proxy.replace("socks5h://", "socks5://")
        rdns = True
    return ProxyConnector.from_url(proxy, rdns=rdns)