Merge pull request #1833 from hlohaus/curl

Update event loop on windows only for old curl_cffi
This commit is contained in:
H Lohaus 2024-04-17 10:35:08 +02:00 committed by GitHub
commit 0f04dacdbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -66,9 +66,12 @@ class LocalProvider:
if message["role"] != "system" if message["role"] != "system"
) + "\nASSISTANT: " ) + "\nASSISTANT: "
def should_not_stop(token_id: int, token: str):
return "USER" not in token
with model.chat_session(system_message, prompt_template): with model.chat_session(system_message, prompt_template):
if stream: if stream:
for token in model.generate(conversation, streaming=True): for token in model.generate(conversation, streaming=True, callback=should_not_stop):
yield token yield token
else: else:
yield model.generate(conversation) yield model.generate(conversation, callback=should_not_stop)

View File

@ -19,8 +19,13 @@ else:
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi # Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32': if sys.platform == 'win32':
try:
from curl_cffi import aio
if not hasattr(aio, "_get_selector"):
if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
except ImportError:
pass
def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try: try: