Fix unittests, use Union typing

pull/1609/head
Heiner Lohaus 4 months ago
parent 74397096b7
commit d733930a2b

@ -35,13 +35,15 @@ class TestPassModel(unittest.TestCase):
response = client.chat.completions.create(messages, "Hello", stream=True) response = client.chat.completions.create(messages, "Hello", stream=True)
for chunk in response: for chunk in response:
self.assertIsInstance(chunk, ChatCompletionChunk) self.assertIsInstance(chunk, ChatCompletionChunk)
self.assertIsInstance(chunk.choices[0].delta.content, str) if chunk.choices[0].delta.content is not None:
self.assertIsInstance(chunk.choices[0].delta.content, str)
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]] messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2) response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
response = list(response) response = list(response)
self.assertEqual(len(response), 2) self.assertEqual(len(response), 3)
for chunk in response: for chunk in response:
self.assertEqual(chunk.choices[0].delta.content, "You ") if chunk.choices[0].delta.content is not None:
self.assertEqual(chunk.choices[0].delta.content, "You ")
def test_stop(self): def test_stop(self):
client = Client(provider=YieldProviderMock) client = Client(provider=YieldProviderMock)

@ -6,7 +6,7 @@ import nest_asyncio
from fastapi import FastAPI, Response, Request from fastapi import FastAPI, Response, Request
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
from pydantic import BaseModel from pydantic import BaseModel
from typing import List from typing import List, Union
import g4f import g4f
import g4f.debug import g4f.debug
@ -16,12 +16,12 @@ from g4f.typing import Messages
class ChatCompletionsConfig(BaseModel): class ChatCompletionsConfig(BaseModel):
messages: Messages messages: Messages
model: str model: str
provider: str | None provider: Union[str, None]
stream: bool = False stream: bool = False
temperature: float | None temperature: Union[float, None]
max_tokens: int = None max_tokens: int = None
stop: list[str] | str | None stop: Union[list[str], str, None]
access_token: str | None access_token: Union[str, None]
class Api: class Api:
def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False, def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False,

@ -17,7 +17,7 @@ from . import get_model_and_provider, get_last_provider
ImageProvider = Union[BaseProvider, object] ImageProvider = Union[BaseProvider, object]
Proxies = Union[dict, str] Proxies = Union[dict, str]
IterResponse = Generator[ChatCompletion | ChatCompletionChunk, None, None] IterResponse = Generator[Union[ChatCompletion, ChatCompletionChunk], None, None]
def read_json(text: str) -> dict: def read_json(text: str) -> dict:
""" """
@ -124,7 +124,7 @@ class Completions():
stream: bool = False, stream: bool = False,
response_format: dict = None, response_format: dict = None,
max_tokens: int = None, max_tokens: int = None,
stop: list[str] | str = None, stop: Union[list[str], str] = None,
**kwargs **kwargs
) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]: ) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]:
if max_tokens is not None: if max_tokens is not None:

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Union
class Model(): class Model():
... ...
@ -52,7 +54,7 @@ class ChatCompletionChunk(Model):
} }
class ChatCompletionMessage(Model): class ChatCompletionMessage(Model):
def __init__(self, content: str | None): def __init__(self, content: Union[str, None]):
self.role = "assistant" self.role = "assistant"
self.content = content self.content = content
@ -72,7 +74,9 @@ class ChatCompletionChoice(Model):
} }
class ChatCompletionDelta(Model): class ChatCompletionDelta(Model):
def __init__(self, content: str | None): content: Union[str, None] = None
def __init__(self, content: Union[str, None]):
if content is not None: if content is not None:
self.content = content self.content = content
@ -80,7 +84,7 @@ class ChatCompletionDelta(Model):
return self.__dict__ return self.__dict__
class ChatCompletionDeltaChoice(Model): class ChatCompletionDeltaChoice(Model):
def __init__(self, delta: ChatCompletionDelta, finish_reason: str | None): def __init__(self, delta: ChatCompletionDelta, finish_reason: Union[str, None]):
self.delta = delta self.delta = delta
self.finish_reason = finish_reason self.finish_reason = finish_reason

Loading…
Cancel
Save