mistralai[patch]: persist async client (#15786)

pull/15789/head
Erick Friis 6 months ago committed by GitHub
parent 3e0cd11f51
commit 323941a90a

@@ -42,8 +42,8 @@ from langchain_core.outputs import (
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.pydantic_v1 import root_validator
-from langchain_core.utils import get_from_dict_or_env
+from langchain_core.pydantic_v1 import SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
 from mistralai.async_client import MistralAsyncClient  # type: ignore[import]
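Context on the new imports: SecretStr (from the pydantic v1 shim) and convert_to_secret_str wrap the raw API key so it is masked in reprs and logs, and the real value is only retrievable via get_secret_value(). A minimal, standalone sketch of that behavior (not part of the patch):

from langchain_core.pydantic_v1 import SecretStr
from langchain_core.utils import convert_to_secret_str

# Wrapping a raw key hides it from str()/repr(); the value is only
# recoverable explicitly via .get_secret_value().
key = convert_to_secret_str("my-api-key")
assert isinstance(key, SecretStr)
print(key)                      # **********
print(key.get_secret_value())   # my-api-key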
@@ -111,18 +111,11 @@ async def acompletion_with_retry(
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
-        client = MistralAsyncClient(
-            api_key=llm.mistral_api_key,
-            endpoint=llm.endpoint,
-            max_retries=llm.max_retries,
-            timeout=llm.timeout,
-            max_concurrent_requests=llm.max_concurrent_requests,
-        )
         stream = kwargs.pop("stream", False)
         if stream:
-            return client.chat_stream(**kwargs)
+            return llm.async_client.chat_stream(**kwargs)
         else:
-            return await client.chat(**kwargs)
+            return await llm.async_client.chat(**kwargs)

     return await _completion_with_retry(**kwargs)
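This is the heart of the change: _completion_with_retry no longer constructs a fresh MistralAsyncClient on every call and instead reuses the one persisted on the model as llm.async_client, so the client (and its connection pool) survives across calls and retries. A rough sketch of the pattern, using a hypothetical DummyAsyncClient in place of the real MistralAsyncClient:

import asyncio
from typing import Any

class DummyAsyncClient:
    """Hypothetical stand-in for MistralAsyncClient."""
    instances = 0

    def __init__(self) -> None:
        DummyAsyncClient.instances += 1

    async def chat(self, **kwargs: Any) -> dict:
        return {"echo": kwargs}

class Model:
    def __init__(self) -> None:
        # The client is created once and persisted on the model ...
        self.async_client = DummyAsyncClient()

    async def acall(self, **kwargs: Any) -> dict:
        # ... and every request reuses it instead of building a new one.
        return await self.async_client.chat(**kwargs)

async def main() -> None:
    m = Model()
    await m.acall(messages=["hi"])
    await m.acall(messages=["hello again"])
    assert DummyAsyncClient.instances == 1  # only one client ever constructed

asyncio.run(main())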
@@ -163,8 +156,9 @@ def _convert_message_to_mistral_chat_message(
 class ChatMistralAI(BaseChatModel):
     """A chat model that uses the MistralAI API."""

-    client: Any  #: :meta private:
-    mistral_api_key: Optional[str] = None
+    client: MistralClient = None  #: :meta private:
+    async_client: MistralAsyncClient = None  #: :meta private:
+    mistral_api_key: Optional[SecretStr] = None
     endpoint: str = DEFAULT_MISTRAL_ENDPOINT
     max_retries: int = 5
     timeout: int = 120
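The client and async_client fields are now declared with their concrete types (filled in by the validator below) and the API key field becomes Optional[SecretStr]. A self-contained sketch of the same pydantic v1 pattern, with a hypothetical FakeClient standing in for MistralClient; note that arbitrary_types_allowed is needed for non-pydantic field types:

from typing import Optional
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator

class FakeClient:
    """Hypothetical stand-in for MistralClient."""
    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

class FakeChatModel(BaseModel):
    class Config:
        arbitrary_types_allowed = True  # allow FakeClient as a field type

    client: FakeClient = None  #: :meta private:
    mistral_api_key: Optional[SecretStr] = None

    @root_validator()
    def _init_client(cls, values: dict) -> dict:
        # Build the persistent client from the (secret) key, once.
        key = values.get("mistral_api_key") or SecretStr("")
        values["client"] = FakeClient(api_key=key.get_secret_value())
        return values

m = FakeChatModel(mistral_api_key=SecretStr("sk-test"))
print(type(m.client).__name__)   # FakeClient
print(m.mistral_api_key)         # **********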
@@ -224,14 +218,23 @@ class ChatMistralAI(BaseChatModel):
                 "Please install it with `pip install mistralai`"
             )
-        values["mistral_api_key"] = get_from_dict_or_env(
-            values, "mistral_api_key", "MISTRAL_API_KEY", default=""
+        values["mistral_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(
+                values, "mistral_api_key", "MISTRAL_API_KEY", default=""
+            )
         )
         values["client"] = MistralClient(
-            api_key=values["mistral_api_key"],
+            api_key=values["mistral_api_key"].get_secret_value(),
             endpoint=values["endpoint"],
             max_retries=values["max_retries"],
             timeout=values["timeout"],
         )
+        values["async_client"] = MistralAsyncClient(
+            api_key=values["mistral_api_key"].get_secret_value(),
+            endpoint=values["endpoint"],
+            max_retries=values["max_retries"],
+            timeout=values["timeout"],
+            max_concurrent_requests=values["max_concurrent_requests"],
+        )
         if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
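With both clients built once inside validate_environment, the sync and async code paths share long-lived clients instead of rebuilding them per request. A usage sketch, assuming langchain-mistralai is installed and MISTRAL_API_KEY is set in the environment:

import asyncio
from langchain_mistralai.chat_models import ChatMistralAI

llm = ChatMistralAI()  # client and async_client are constructed here, once

print(llm.invoke("Say hi").content)  # sync path goes through llm.client

async def main() -> None:
    msg = await llm.ainvoke("Say hi again")  # async path reuses llm.async_client
    print(msg.content)
    async for chunk in llm.astream("Count to three"):  # streaming shares the same client
        print(chunk.content, end="")

asyncio.run(main())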
