Merge pull request #1509 from hlohaus/sort

Add ProviderModelMixin for model selection
H Lohaus, 8 months ago, committed by GitHub
commit 2b140a3255

g4f/Provider/DeepInfra.py

@@ -1,18 +1,27 @@
 from __future__ import annotations

 import json
+import requests
 from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..requests import StreamSession

-class DeepInfra(AsyncGeneratorProvider):
+class DeepInfra(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://deepinfra.com"
     working = True
     supports_stream = True
     supports_message_history = True
+    default_model = 'meta-llama/Llama-2-70b-chat-hf'

     @staticmethod
+    def get_models():
+        url = 'https://api.deepinfra.com/models/featured'
+        models = requests.get(url).json()
+        return [model['model_name'] for model in models]
+
+    @classmethod
     async def create_async_generator(
+        cls,
         model: str,
         messages: Messages,
         stream: bool,
@@ -21,8 +30,6 @@ class DeepInfra(AsyncGeneratorProvider):
         auth: str = None,
         **kwargs
     ) -> AsyncResult:
-        if not model:
-            model = 'meta-llama/Llama-2-70b-chat-hf'
         headers = {
             'Accept-Encoding': 'gzip, deflate, br',
             'Accept-Language': 'en-US',
@@ -49,7 +56,7 @@ class DeepInfra(AsyncGeneratorProvider):
             impersonate="chrome110"
         ) as session:
             json_data = {
-                'model'   : model,
+                'model'   : cls.get_model(model),
                 'messages': messages,
                 'stream'  : True
             }
@@ -70,6 +77,7 @@ class DeepInfra(AsyncGeneratorProvider):
                         if token:
                             if first:
                                 token = token.lstrip()
-                                first = False
+                                if token:
+                                    first = False
                             yield token
             except Exception:
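
Note: the new get_models hook queries DeepInfra's featured-models endpoint at runtime instead of hard-coding a list. A minimal standalone sketch of the same call; the timeout and status check are illustrative hardening, not part of the commit:

    # Sketch: fetch DeepInfra's featured models, as get_models() above does.
    # The timeout and raise_for_status() are assumptions, not in the diff.
    import requests

    def list_featured_models() -> list[str]:
        response = requests.get("https://api.deepinfra.com/models/featured", timeout=10)
        response.raise_for_status()
        return [model["model_name"] for model in response.json()]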

g4f/Provider/HuggingChat.py

@@ -5,11 +5,11 @@ import json, uuid
 from aiohttp import ClientSession

 from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from .helper import format_prompt, get_cookies

-class HuggingChat(AsyncGeneratorProvider):
+class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://huggingface.co/chat"
     working = True
     default_model = "meta-llama/Llama-2-70b-chat-hf"
@@ -21,7 +21,7 @@ class HuggingChat(AsyncGeneratorProvider):
         "mistralai/Mistral-7B-Instruct-v0.2",
         "openchat/openchat-3.5-0106"
     ]
-    model_map = {
+    model_aliases = {
         "openchat/openchat_3.5": "openchat/openchat-3.5-1210",
         "mistralai/Mixtral-8x7B-Instruct-v0.1": "mistralai/Mistral-7B-Instruct-v0.2"
     }
@@ -37,12 +37,6 @@ class HuggingChat(AsyncGeneratorProvider):
         cookies: dict = None,
         **kwargs
     ) -> AsyncResult:
-        if not model:
-            model = cls.default_model
-        elif model in cls.model_map:
-            model = cls.model_map[model]
-        elif model not in cls.models:
-            raise ValueError(f"Model is not supported: {model}")
         if not cookies:
             cookies = get_cookies(".huggingface.co")
@@ -53,7 +47,7 @@ class HuggingChat(AsyncGeneratorProvider):
             cookies=cookies,
             headers=headers
         ) as session:
-            async with session.post(f"{cls.url}/conversation", json={"model": model}, proxy=proxy) as response:
+            async with session.post(f"{cls.url}/conversation", json={"model": cls.get_model(model)}, proxy=proxy) as response:
                 conversation_id = (await response.json())["conversationId"]
                 send = {

g4f/Provider/Liaobots.py

@@ -5,7 +5,7 @@ import uuid
 from aiohttp import ClientSession

 from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

 models = {
     "gpt-4": {
@@ -70,13 +70,17 @@ models = {
     }
 }

-class Liaobots(AsyncGeneratorProvider):
+class Liaobots(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://liaobots.site"
     working = True
     supports_message_history = True
     supports_gpt_35_turbo = True
     supports_gpt_4 = True
+    default_model = "gpt-3.5-turbo"
+    models = [m for m in models]
+    model_aliases = {
+        "claude-v2": "claude-2"
+    }
     _auth_code = None
     _cookie_jar = None
@@ -89,7 +93,6 @@ class Liaobots(AsyncGeneratorProvider):
         proxy: str = None,
         **kwargs
     ) -> AsyncResult:
-        model = model if model in models else "gpt-3.5-turbo"
         headers = {
             "authority": "liaobots.com",
             "content-type": "application/json",
@@ -122,7 +125,7 @@ class Liaobots(AsyncGeneratorProvider):
             data = {
                 "conversationId": str(uuid.uuid4()),
-                "model": models[model],
+                "model": models[cls.get_model(model)],
                 "messages": messages,
                 "key": "",
                 "prompt": kwargs.get("system_message", "You are ChatGPT, a large language model trained by OpenAI. Follow the user's instructions carefully."),

g4f/Provider/Llama2.py

@@ -3,18 +3,24 @@ from __future__ import annotations
 from aiohttp import ClientSession

 from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

-models = {
-    "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat",
-    "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat",
-    "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat",
-}

-class Llama2(AsyncGeneratorProvider):
+class Llama2(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://www.llama2.ai"
     working = True
     supports_message_history = True
+    default_model = "meta/llama-2-70b-chat"
+    models = [
+        "meta/llama-2-7b-chat",
+        "meta/llama-2-13b-chat",
+        "meta/llama-2-70b-chat",
+    ]
+    model_aliases = {
+        "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat",
+        "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat",
+        "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat",
+    }

     @classmethod
     async def create_async_generator(
@@ -24,10 +30,6 @@ class Llama2(AsyncGeneratorProvider):
         proxy: str = None,
         **kwargs
     ) -> AsyncResult:
-        if not model:
-            model = "meta/llama-2-70b-chat"
-        elif model in models:
-            model = models[model]
         headers = {
             "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/118.0",
             "Accept": "*/*",
@@ -48,7 +50,7 @@ class Llama2(AsyncGeneratorProvider):
             prompt = format_prompt(messages)
             data = {
                 "prompt": prompt,
-                "model": model,
+                "model": cls.get_model(model),
                 "systemPrompt": kwargs.get("system_message", "You are a helpful assistant."),
                 "temperature": kwargs.get("temperature", 0.75),
                 "topP": kwargs.get("top_p", 0.9),

g4f/Provider/PerplexityLabs.py

@@ -5,20 +5,21 @@ import json
 from aiohttp import ClientSession

 from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

 API_URL = "https://labs-api.perplexity.ai/socket.io/"
 WS_URL = "wss://labs-api.perplexity.ai/socket.io/"

-class PerplexityLabs(AsyncGeneratorProvider):
+class PerplexityLabs(AsyncGeneratorProvider, ProviderModelMixin):
     url = "https://labs.perplexity.ai"
     working = True
-    supports_gpt_35_turbo = True
-    models = ['pplx-7b-online', 'pplx-70b-online', 'pplx-7b-chat', 'pplx-70b-chat', 'mistral-7b-instruct',
-              'codellama-34b-instruct', 'llama-2-70b-chat', 'llava-7b-chat', 'mixtral-8x7b-instruct',
-              'mistral-medium', 'related']
+    models = [
+        'pplx-7b-online', 'pplx-70b-online', 'pplx-7b-chat', 'pplx-70b-chat', 'mistral-7b-instruct',
+        'codellama-34b-instruct', 'llama-2-70b-chat', 'llava-7b-chat', 'mixtral-8x7b-instruct',
+        'mistral-medium', 'related'
+    ]
     default_model = 'pplx-70b-online'
-    model_map = {
+    model_aliases = {
         "mistralai/Mistral-7B-Instruct-v0.1": "mistral-7b-instruct",
         "meta-llama/Llama-2-70b-chat-hf": "llama-2-70b-chat",
         "mistralai/Mixtral-8x7B-Instruct-v0.1": "mixtral-8x7b-instruct",
@@ -33,12 +34,6 @@ class PerplexityLabs(AsyncGeneratorProvider):
         proxy: str = None,
         **kwargs
     ) -> AsyncResult:
-        if not model:
-            model = cls.default_model
-        elif model in cls.model_map:
-            model = cls.model_map[model]
-        elif model not in cls.models:
-            raise ValueError(f"Model is not supported: {model}")
         headers = {
             "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:121.0) Gecko/20100101 Firefox/121.0",
             "Accept": "*/*",
@@ -78,7 +73,7 @@ class PerplexityLabs(AsyncGeneratorProvider):
             message_data = {
                 'version': '2.2',
                 'source': 'default',
-                'model': model,
+                'model': cls.get_model(model),
                 'messages': messages
             }
             await ws.send_str('42' + json.dumps(['perplexity_playground', message_data]))

g4f/Provider/base_provider.py

@@ -8,7 +8,7 @@ from inspect import signature, Parameter
 from .helper import get_cookies, format_prompt
 from ..typing import CreateResult, AsyncResult, Messages, Union
 from ..base_provider import BaseProvider
-from ..errors import NestAsyncioError
+from ..errors import NestAsyncioError, ModelNotSupportedError

 if sys.version_info < (3, 10):
     NoneType = type(None)
@@ -252,3 +252,22 @@ class AsyncGeneratorProvider(AsyncProvider):
             AsyncResult: An asynchronous generator yielding results.
         """
         raise NotImplementedError()
+
+class ProviderModelMixin:
+    default_model: str
+    models: list[str] = []
+    model_aliases: dict[str, str] = {}
+
+    @classmethod
+    def get_models(cls) -> list[str]:
+        return cls.models
+
+    @classmethod
+    def get_model(cls, model: str) -> str:
+        if not model:
+            return cls.default_model
+        elif model in cls.model_aliases:
+            return cls.model_aliases[model]
+        elif model not in cls.get_models():
+            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+        return model
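
Note: this mixin is the heart of the change. Every provider now resolves model names in one place, in a fixed order: an empty name falls back to default_model, a known alias maps to its canonical name, and any name not in get_models() raises ModelNotSupportedError. A rough sketch of the behavior; ExampleProvider is illustrative, not part of the commit:

    # Sketch: model resolution through ProviderModelMixin.
    # Import path matches this PR's layout of base_provider.py.
    from g4f.Provider.base_provider import ProviderModelMixin

    class ExampleProvider(ProviderModelMixin):
        default_model = "model-a"
        models = ["model-a", "model-b"]
        model_aliases = {"legacy-name": "model-b"}

    assert ExampleProvider.get_model("") == "model-a"             # empty -> default
    assert ExampleProvider.get_model("legacy-name") == "model-b"  # alias -> canonical
    assert ExampleProvider.get_model("model-b") == "model-b"      # known name passes through
    # ExampleProvider.get_model("unknown") raises ModelNotSupportedError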

g4f/Provider/needs_auth/OpenaiChat.py

@@ -10,22 +10,15 @@ from selenium.webdriver.common.by import By
 from selenium.webdriver.support.ui import WebDriverWait
 from selenium.webdriver.support import expected_conditions as EC

-from ..base_provider import AsyncGeneratorProvider
+from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
 from ..helper import format_prompt, get_cookies
 from ...webdriver import get_browser, get_driver_cookies
 from ...typing import AsyncResult, Messages
 from ...requests import StreamSession
 from ...image import to_image, to_bytes, ImageType, ImageResponse

-# Aliases for model names
-MODELS = {
-    "gpt-3.5": "text-davinci-002-render-sha",
-    "gpt-3.5-turbo": "text-davinci-002-render-sha",
-    "gpt-4": "gpt-4",
-    "gpt-4-gizmo": "gpt-4-gizmo"
-}
-
-class OpenaiChat(AsyncGeneratorProvider):
+class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
     """A class for creating and managing conversations with OpenAI chat service"""

     url = "https://chat.openai.com"
@@ -33,6 +26,11 @@ class OpenaiChat(AsyncGeneratorProvider):
     needs_auth = True
     supports_gpt_35_turbo = True
     supports_gpt_4 = True
+    default_model = None
+    models = ["text-davinci-002-render-sha", "gpt-4", "gpt-4-gizmo"]
+    model_aliases = {
+        "gpt-3.5-turbo": "text-davinci-002-render-sha",
+    }
     _cookies: dict = {}
     _default_model: str = None
@@ -91,7 +89,7 @@ class OpenaiChat(AsyncGeneratorProvider):
         )

     @classmethod
-    async def _upload_image(
+    async def upload_image(
         cls,
         session: StreamSession,
         headers: dict,
@@ -150,7 +148,7 @@ class OpenaiChat(AsyncGeneratorProvider):
         return ImageResponse(download_url, image_data["file_name"], image_data)

     @classmethod
-    async def _get_default_model(cls, session: StreamSession, headers: dict):
+    async def get_default_model(cls, session: StreamSession, headers: dict):
         """
         Get the default model name from the service
@@ -161,20 +159,17 @@ class OpenaiChat(AsyncGeneratorProvider):
         Returns:
             The default model name as a string
         """
-        # Check the cache for the default model
-        if cls._default_model:
-            return cls._default_model
-        # Get the models data from the service
-        async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
-            data = await response.json()
-            if "categories" in data:
-                cls._default_model = data["categories"][-1]["default_model"]
-            else:
-                raise RuntimeError(f"Response: {data}")
-        return cls._default_model
+        if not cls.default_model:
+            async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
+                data = await response.json()
+                if "categories" in data:
+                    cls.default_model = data["categories"][-1]["default_model"]
+                else:
+                    raise RuntimeError(f"Response: {data}")
+        return cls.default_model

     @classmethod
-    def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
+    def create_messages(cls, prompt: str, image_response: ImageResponse = None):
         """
         Create a list of messages for the user input
@@ -222,7 +217,7 @@ class OpenaiChat(AsyncGeneratorProvider):
         return messages

     @classmethod
-    async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
+    async def get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
         """
         Retrieves the image response based on the message content.
@@ -257,7 +252,7 @@ class OpenaiChat(AsyncGeneratorProvider):
             raise RuntimeError(f"Error in downloading image: {e}")

     @classmethod
-    async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
+    async def delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
         """
         Deletes a conversation by setting its visibility to False.
@@ -322,7 +317,6 @@ class OpenaiChat(AsyncGeneratorProvider):
         Raises:
             RuntimeError: If an error occurs during processing.
         """
-        model = MODELS.get(model, model)
         if not parent_id:
             parent_id = str(uuid.uuid4())
         if not cookies:
@@ -333,7 +327,7 @@ class OpenaiChat(AsyncGeneratorProvider):
             login_url = os.environ.get("G4F_LOGIN_URL")
             if login_url:
                 yield f"Please login: [ChatGPT]({login_url})\n\n"
-            access_token, cookies = cls._browse_access_token(proxy)
+            access_token, cookies = cls.browse_access_token(proxy)
             cls._cookies = cookies
         headers = {"Authorization": f"Bearer {access_token}"}
@@ -344,12 +338,10 @@ class OpenaiChat(AsyncGeneratorProvider):
             timeout=timeout,
             cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"])
         ) as session:
-            if not model:
-                model = await cls._get_default_model(session, headers)
             try:
                 image_response = None
                 if image:
-                    image_response = await cls._upload_image(session, headers, image)
+                    image_response = await cls.upload_image(session, headers, image)
                     yield image_response
             except Exception as e:
                 yield e
@@ -357,15 +349,15 @@ class OpenaiChat(AsyncGeneratorProvider):
             while not end_turn.is_end:
                 data = {
                     "action": action,
-                    "arkose_token": await cls._get_arkose_token(session),
+                    "arkose_token": await cls.get_arkose_token(session),
                     "conversation_id": conversation_id,
                     "parent_message_id": parent_id,
-                    "model": model,
+                    "model": cls.get_model(model or await cls.get_default_model(session, headers)),
                     "history_and_training_disabled": history_disabled and not auto_continue,
                 }
                 if action != "continue":
                     prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
-                    data["messages"] = cls._create_messages(prompt, image_response)
+                    data["messages"] = cls.create_messages(prompt, image_response)
                 async with session.post(
                     f"{cls.url}/backend-api/conversation",
                     json=data,
@@ -391,7 +383,7 @@ class OpenaiChat(AsyncGeneratorProvider):
                         if "message_type" not in line["message"]["metadata"]:
                             continue
                         try:
-                            image_response = await cls._get_generated_image(session, headers, line)
+                            image_response = await cls.get_generated_image(session, headers, line)
                             if image_response:
                                 yield image_response
                         except Exception as e:
@@ -422,10 +414,10 @@ class OpenaiChat(AsyncGeneratorProvider):
                     action = "continue"
                     await asyncio.sleep(5)
         if history_disabled and auto_continue:
-            await cls._delete_conversation(session, headers, conversation_id)
+            await cls.delete_conversation(session, headers, conversation_id)

     @classmethod
-    def _browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
+    def browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
         """
         Browse to obtain an access token.
@@ -452,7 +444,7 @@ class OpenaiChat(AsyncGeneratorProvider):
             driver.quit()

     @classmethod
-    async def _get_arkose_token(cls, session: StreamSession) -> str:
+    async def get_arkose_token(cls, session: StreamSession) -> str:
         """
         Obtain an Arkose token for the session.

g4f/errors.py

@@ -27,3 +27,6 @@ class VersionNotFoundError(Exception):

 class NestAsyncioError(Exception):
     pass
+
+class ModelNotSupportedError(Exception):
+    pass
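
Note: callers that previously caught the generic ValueError for an unsupported model can now catch the dedicated type. A brief sketch, reusing the illustrative ExampleProvider from the ProviderModelMixin note above; the fallback policy here is an assumption:

    # Sketch: handling the new error when a model name is not recognized.
    from g4f.errors import ModelNotSupportedError

    try:
        model = ExampleProvider.get_model("unknown-model")
    except ModelNotSupportedError:
        model = ExampleProvider.default_model  # illustrative fallback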

g4f/image.py

@@ -20,23 +20,23 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image.Image:
         try:
             import cairosvg
         except ImportError:
-            raise RuntimeError('Install "cairosvg" package for open svg images')
+            raise RuntimeError('Install "cairosvg" package for svg images')
         if not isinstance(image, bytes):
             image = image.read()
         buffer = BytesIO()
         cairosvg.svg2png(image, write_to=buffer)
-        image = Image.open(buffer)
+        return Image.open(buffer)
     if isinstance(image, str):
         is_data_uri_an_image(image)
         image = extract_data_uri(image)
     if isinstance(image, bytes):
         is_accepted_format(image)
-        image = Image.open(BytesIO(image))
+        return Image.open(BytesIO(image))
     elif not isinstance(image, Image.Image):
         image = Image.open(image)
         copy = image.copy()
         copy.format = image.format
-        image = copy
+        return copy
     return image

@@ -138,6 +138,7 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
     Returns:
         Image.Image: The processed image.
     """
+    # Fix orientation
     orientation = get_orientation(img)
     if orientation:
         if orientation > 4:
@@ -148,7 +149,14 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image:
             img = img.transpose(Image.ROTATE_270)
         if orientation in [7, 8]:
             img = img.transpose(Image.ROTATE_90)
+    # Resize image
     img.thumbnail((new_width, new_height))
+    # Remove transparency
+    if img.mode != "RGB":
+        img.load()
+        white = Image.new('RGB', img.size, (255, 255, 255))
+        white.paste(img, mask=img.split()[3])
+        return white
     return img

@@ -163,8 +171,6 @@ def to_base64(image: Image.Image, compression_rate: float) -> str:
         str: The base64-encoded image.
     """
     output_buffer = BytesIO()
-    if image.mode != "RGB":
-        image = image.convert('RGB')
     image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
     return base64.b64encode(output_buffer.getvalue()).decode()
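
Note: process_image now flattens transparent images onto a white background before to_base64 encodes them as JPEG, since JPEG has no alpha channel; img.split()[3] assumes a 4-channel (RGBA) source. A standalone sketch of the same technique, with illustrative file names:

    # Sketch: flatten an RGBA image onto white before saving as JPEG,
    # mirroring the process_image change above.
    from PIL import Image

    img = Image.open("input.png").convert("RGBA")
    white = Image.new("RGB", img.size, (255, 255, 255))
    white.paste(img, mask=img.split()[3])  # alpha channel as the paste mask
    white.save("output.jpg", format="JPEG", quality=90)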
