mirror of
https://github.com/xtekky/gpt4free.git
synced 2024-11-10 19:11:01 +00:00
Improve tests
This commit is contained in:
parent
9cf2ee0279
commit
9cbe9c1ccb
4
.github/workflows/copilot.yml
vendored
4
.github/workflows/copilot.yml
vendored
@ -9,7 +9,9 @@ on:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
permissions: write-all
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
review:
|
||||
|
2
.github/workflows/unittest.yml
vendored
2
.github/workflows/unittest.yml
vendored
@ -16,4 +16,4 @@ jobs:
|
||||
- name: Install requirements
|
||||
run: pip install -r requirements.txt
|
||||
- name: Run tests
|
||||
run: python -m etc.unittest.main
|
||||
run: python -m etc.unittest
|
6
etc/unittest/__main__.py
Normal file
6
etc/unittest/__main__.py
Normal file
@ -0,0 +1,6 @@
|
||||
import unittest
|
||||
from .asyncio import *
|
||||
from .backend import *
|
||||
from .main import *
|
||||
|
||||
unittest.main()
|
57
etc/unittest/asyncio.py
Normal file
57
etc/unittest/asyncio.py
Normal file
@ -0,0 +1,57 @@
|
||||
from .include import DEFAULT_MESSAGES
|
||||
import asyncio
|
||||
import nest_asyncio
|
||||
import unittest
|
||||
import g4f
|
||||
from g4f import ChatCompletion
|
||||
from .mocks import ProviderMock, AsyncProviderMock, AsyncGeneratorProviderMock
|
||||
|
||||
class TestChatCompletion(unittest.TestCase):
|
||||
|
||||
async def run_exception(self):
|
||||
return ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
|
||||
|
||||
def test_exception(self):
|
||||
self.assertRaises(g4f.errors.NestAsyncioError, asyncio.run, self.run_exception())
|
||||
|
||||
def test_create(self):
|
||||
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
def test_create_generator(self):
|
||||
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_base(self):
|
||||
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
async def test_async(self):
|
||||
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
async def test_create_generator(self):
|
||||
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
class TestChatCompletionNestAsync(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
nest_asyncio.apply()
|
||||
|
||||
async def test_create(self):
|
||||
result = await ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
async def test_nested(self):
|
||||
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
async def test_nested_generator(self):
|
||||
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock)
|
||||
self.assertEqual("Mock",result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
38
etc/unittest/backend.py
Normal file
38
etc/unittest/backend.py
Normal file
@ -0,0 +1,38 @@
|
||||
from . import include
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from .mocks import ProviderMock
|
||||
import g4f
|
||||
from g4f.gui.server.backend import Backend_Api, get_error_message
|
||||
|
||||
class TestBackendApi(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.app = MagicMock()
|
||||
self.api = Backend_Api(self.app)
|
||||
|
||||
def test_version(self):
|
||||
response = self.api.get_version()
|
||||
self.assertIn("version", response)
|
||||
self.assertIn("latest_version", response)
|
||||
|
||||
def test_get_models(self):
|
||||
response = self.api.get_models()
|
||||
self.assertIsInstance(response, list)
|
||||
self.assertTrue(len(response) > 0)
|
||||
|
||||
def test_get_providers(self):
|
||||
response = self.api.get_providers()
|
||||
self.assertIsInstance(response, list)
|
||||
self.assertTrue(len(response) > 0)
|
||||
|
||||
class TestUtilityFunctions(unittest.TestCase):
|
||||
|
||||
def test_get_error_message(self):
|
||||
g4f.debug.last_provider = ProviderMock
|
||||
exception = Exception("Message")
|
||||
result = get_error_message(exception)
|
||||
self.assertEqual("ProviderMock: Exception: Message", result)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
11
etc/unittest/include.py
Normal file
11
etc/unittest/include.py
Normal file
@ -0,0 +1,11 @@
|
||||
import sys
|
||||
import pathlib
|
||||
|
||||
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
||||
|
||||
import g4f
|
||||
|
||||
g4f.debug.logging = False
|
||||
g4f.debug.version_check = False
|
||||
|
||||
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
|
@ -1,75 +1,37 @@
|
||||
import sys
|
||||
import pathlib
|
||||
from .include import DEFAULT_MESSAGES
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path.append(str(pathlib.Path(__file__).parent.parent.parent))
|
||||
|
||||
import asyncio
|
||||
import g4f
|
||||
from g4f import ChatCompletion, get_last_provider
|
||||
from g4f.gui.server.backend import Backend_Api, get_error_message
|
||||
from g4f.base_provider import BaseProvider
|
||||
|
||||
g4f.debug.logging = False
|
||||
g4f.debug.version_check = False
|
||||
|
||||
class MockProvider(BaseProvider):
|
||||
working = True
|
||||
|
||||
def create_completion(
|
||||
model, messages, stream, **kwargs
|
||||
):
|
||||
yield "Mock"
|
||||
|
||||
async def create_async(
|
||||
model, messages, **kwargs
|
||||
):
|
||||
return "Mock"
|
||||
|
||||
class TestBackendApi(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.app = MagicMock()
|
||||
self.api = Backend_Api(self.app)
|
||||
|
||||
def test_version(self):
|
||||
response = self.api.get_version()
|
||||
self.assertIn("version", response)
|
||||
self.assertIn("latest_version", response)
|
||||
from g4f.Provider import RetryProvider
|
||||
from .mocks import ProviderMock
|
||||
|
||||
class TestChatCompletion(unittest.TestCase):
|
||||
|
||||
def test_create_default(self):
|
||||
messages = [{'role': 'user', 'content': 'Hello'}]
|
||||
result = ChatCompletion.create(g4f.models.default, messages)
|
||||
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES)
|
||||
if "Good" not in result and "Hi" not in result:
|
||||
self.assertIn("Hello", result)
|
||||
|
||||
def test_get_last_provider(self):
|
||||
messages = [{'role': 'user', 'content': 'Hello'}]
|
||||
ChatCompletion.create(g4f.models.default, messages, MockProvider)
|
||||
self.assertEqual(get_last_provider(), MockProvider)
|
||||
|
||||
def test_bing_provider(self):
|
||||
messages = [{'role': 'user', 'content': 'Hello'}]
|
||||
provider = g4f.Provider.Bing
|
||||
result = ChatCompletion.create(g4f.models.default, messages, provider)
|
||||
result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, provider)
|
||||
self.assertIn("Bing", result)
|
||||
|
||||
class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase):
|
||||
class TestGetLastProvider(unittest.TestCase):
|
||||
|
||||
async def test_async(self):
|
||||
messages = [{'role': 'user', 'content': 'Hello'}]
|
||||
result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider)
|
||||
self.assertEqual("Mock", result)
|
||||
def test_get_last_provider(self):
|
||||
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
|
||||
self.assertEqual(get_last_provider(), ProviderMock)
|
||||
|
||||
class TestUtilityFunctions(unittest.TestCase):
|
||||
def test_get_last_provider_retry(self):
|
||||
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock]))
|
||||
self.assertEqual(get_last_provider(), ProviderMock)
|
||||
|
||||
def test_get_error_message(self):
|
||||
g4f.debug.last_provider = g4f.Provider.Bing
|
||||
exception = Exception("Message")
|
||||
result = get_error_message(exception)
|
||||
self.assertEqual("Bing: Exception: Message", result)
|
||||
def test_get_last_provider_async(self):
|
||||
coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
|
||||
asyncio.run(coroutine)
|
||||
self.assertEqual(get_last_provider(), ProviderMock)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
25
etc/unittest/mocks.py
Normal file
25
etc/unittest/mocks.py
Normal file
@ -0,0 +1,25 @@
|
||||
from g4f.Provider.base_provider import AbstractProvider, AsyncProvider, AsyncGeneratorProvider
|
||||
|
||||
class ProviderMock(AbstractProvider):
|
||||
working = True
|
||||
|
||||
def create_completion(
|
||||
model, messages, stream, **kwargs
|
||||
):
|
||||
yield "Mock"
|
||||
|
||||
class AsyncProviderMock(AsyncProvider):
|
||||
working = True
|
||||
|
||||
async def create_async(
|
||||
model, messages, **kwargs
|
||||
):
|
||||
return "Mock"
|
||||
|
||||
class AsyncGeneratorProviderMock(AsyncGeneratorProvider):
|
||||
working = True
|
||||
|
||||
async def create_async_generator(
|
||||
model, messages, stream, **kwargs
|
||||
):
|
||||
yield "Mock"
|
@ -64,12 +64,7 @@ class Bing(AsyncGeneratorProvider):
|
||||
prompt = messages[-1]["content"]
|
||||
context = create_context(messages[:-1])
|
||||
|
||||
if not cookies:
|
||||
cookies = Defaults.cookies
|
||||
else:
|
||||
for key, value in Defaults.cookies.items():
|
||||
if key not in cookies:
|
||||
cookies[key] = value
|
||||
cookies = {**Defaults.cookies, **cookies} if cookies else Defaults.cookies
|
||||
|
||||
gpt4_turbo = True if model.startswith("gpt-4-turbo") else False
|
||||
|
||||
@ -207,10 +202,12 @@ def create_message(
|
||||
request_id = str(uuid.uuid4())
|
||||
struct = {
|
||||
'arguments': [{
|
||||
'source': 'cib', 'optionsSets': options_sets,
|
||||
'source': 'cib',
|
||||
'optionsSets': options_sets,
|
||||
'allowedMessageTypes': Defaults.allowedMessageTypes,
|
||||
'sliceIds': Defaults.sliceIds,
|
||||
'traceId': os.urandom(16).hex(), 'isStartOfSession': True,
|
||||
'traceId': os.urandom(16).hex(),
|
||||
'isStartOfSession': True,
|
||||
'requestId': request_id,
|
||||
'message': {
|
||||
**Defaults.location,
|
||||
|
@ -5,8 +5,8 @@ from asyncio import AbstractEventLoop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from abc import abstractmethod
|
||||
from inspect import signature, Parameter
|
||||
from .helper import get_event_loop, get_cookies, format_prompt
|
||||
from ..typing import CreateResult, AsyncResult, Messages
|
||||
from .helper import get_cookies, format_prompt
|
||||
from ..typing import CreateResult, AsyncResult, Messages, Union
|
||||
from ..base_provider import BaseProvider
|
||||
from ..errors import NestAsyncioError
|
||||
|
||||
@ -20,6 +20,17 @@ if sys.platform == 'win32':
|
||||
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
def get_running_loop() -> Union[AbstractEventLoop, None]:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if not hasattr(loop.__class__, "_nest_patched"):
|
||||
raise NestAsyncioError(
|
||||
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
|
||||
)
|
||||
return loop
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
class AbstractProvider(BaseProvider):
|
||||
"""
|
||||
Abstract class for providing asynchronous functionality to derived classes.
|
||||
@ -56,7 +67,7 @@ class AbstractProvider(BaseProvider):
|
||||
|
||||
return await asyncio.wait_for(
|
||||
loop.run_in_executor(executor, create_func),
|
||||
timeout=kwargs.get("timeout", 0)
|
||||
timeout=kwargs.get("timeout")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -118,14 +129,7 @@ class AsyncProvider(AbstractProvider):
|
||||
Returns:
|
||||
CreateResult: The result of the completion creation.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if not hasattr(loop.__class__, "_nest_patched"):
|
||||
raise NestAsyncioError(
|
||||
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
|
||||
)
|
||||
except RuntimeError:
|
||||
pass
|
||||
get_running_loop()
|
||||
yield asyncio.run(cls.create_async(model, messages, **kwargs))
|
||||
|
||||
@staticmethod
|
||||
@ -180,15 +184,12 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||
Returns:
|
||||
CreateResult: The result of the streaming completion creation.
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
if not hasattr(loop.__class__, "_nest_patched"):
|
||||
raise NestAsyncioError(
|
||||
'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.'
|
||||
)
|
||||
except RuntimeError:
|
||||
loop = get_running_loop()
|
||||
new_loop = False
|
||||
if not loop:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
new_loop = True
|
||||
|
||||
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
|
||||
gen = generator.__aiter__()
|
||||
@ -199,6 +200,10 @@ class AsyncGeneratorProvider(AsyncProvider):
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
if new_loop:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
@classmethod
|
||||
async def create_async(
|
||||
cls,
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
import json
|
||||
from flask import request, Flask
|
||||
from typing import Generator
|
||||
from g4f import debug, version, models
|
||||
from g4f import version, models
|
||||
from g4f import _all_models, get_last_provider, ChatCompletion
|
||||
from g4f.image import is_allowed_extension, to_image
|
||||
from g4f.errors import VersionNotFoundError
|
||||
@ -10,7 +10,6 @@ from g4f.Provider import __providers__
|
||||
from g4f.Provider.bing.create_images import patch_provider
|
||||
from .internet import get_search_message
|
||||
|
||||
debug.logging = True
|
||||
|
||||
class Backend_Api:
|
||||
"""
|
||||
|
14
g4f/image.py
14
g4f/image.py
@ -112,7 +112,7 @@ def get_orientation(image: Image.Image) -> int:
|
||||
"""
|
||||
exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
|
||||
if exif_data is not None:
|
||||
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
|
||||
orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
|
||||
if orientation is not None:
|
||||
return orientation
|
||||
|
||||
@ -156,23 +156,23 @@ def to_base64(image: Image.Image, compression_rate: float) -> str:
|
||||
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
|
||||
return base64.b64encode(output_buffer.getvalue()).decode()
|
||||
|
||||
def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str:
|
||||
def format_images_markdown(images, alt: str, preview: str="{image}?w=200&h=200") -> str:
|
||||
"""
|
||||
Formats the given images as a markdown string.
|
||||
|
||||
Args:
|
||||
images: The images to format.
|
||||
prompt (str): The prompt for the images.
|
||||
alt (str): The alt for the images.
|
||||
preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
|
||||
|
||||
Returns:
|
||||
str: The formatted markdown string.
|
||||
"""
|
||||
if isinstance(images, list):
|
||||
images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
|
||||
images = "\n".join(images)
|
||||
if isinstance(images, str):
|
||||
images = f"[![{alt}]({preview.replace('{image}', images)})]({images})"
|
||||
else:
|
||||
images = f"[![{prompt}]({images})]({images})"
|
||||
images = [f"[![#{idx+1} {alt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)]
|
||||
images = "\n".join(images)
|
||||
start_flag = "<!-- generated images start -->\n"
|
||||
end_flag = "<!-- generated images end -->\n"
|
||||
return f"\n{start_flag}{images}\n{end_flag}\n"
|
||||
|
@ -18,7 +18,14 @@ __all__ = [
|
||||
'AsyncGenerator',
|
||||
'Generator',
|
||||
'Tuple',
|
||||
'Union',
|
||||
'List',
|
||||
'Dict',
|
||||
'Type',
|
||||
'TypedDict',
|
||||
'SHA256',
|
||||
'CreateResult',
|
||||
'AsyncResult',
|
||||
'Messages',
|
||||
'ImageType'
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user