|
|
|
@ -42,8 +42,8 @@ from langchain_core.outputs import (
|
|
|
|
|
ChatGenerationChunk,
|
|
|
|
|
ChatResult,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.pydantic_v1 import root_validator
|
|
|
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
|
|
|
from langchain_core.pydantic_v1 import SecretStr, root_validator
|
|
|
|
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
|
|
|
|
|
|
|
|
|
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
|
|
|
|
|
from mistralai.async_client import MistralAsyncClient # type: ignore[import]
|
|
|
|
@ -111,18 +111,11 @@ async def acompletion_with_retry(
|
|
|
|
|
|
|
|
|
|
@retry_decorator
|
|
|
|
|
async def _completion_with_retry(**kwargs: Any) -> Any:
|
|
|
|
|
client = MistralAsyncClient(
|
|
|
|
|
api_key=llm.mistral_api_key,
|
|
|
|
|
endpoint=llm.endpoint,
|
|
|
|
|
max_retries=llm.max_retries,
|
|
|
|
|
timeout=llm.timeout,
|
|
|
|
|
max_concurrent_requests=llm.max_concurrent_requests,
|
|
|
|
|
)
|
|
|
|
|
stream = kwargs.pop("stream", False)
|
|
|
|
|
if stream:
|
|
|
|
|
return client.chat_stream(**kwargs)
|
|
|
|
|
return llm.async_client.chat_stream(**kwargs)
|
|
|
|
|
else:
|
|
|
|
|
return await client.chat(**kwargs)
|
|
|
|
|
return await llm.async_client.chat(**kwargs)
|
|
|
|
|
|
|
|
|
|
return await _completion_with_retry(**kwargs)
|
|
|
|
|
|
|
|
|
@ -163,8 +156,9 @@ def _convert_message_to_mistral_chat_message(
|
|
|
|
|
class ChatMistralAI(BaseChatModel):
|
|
|
|
|
"""A chat model that uses the MistralAI API."""
|
|
|
|
|
|
|
|
|
|
client: Any #: :meta private:
|
|
|
|
|
mistral_api_key: Optional[str] = None
|
|
|
|
|
client: MistralClient = None #: :meta private:
|
|
|
|
|
async_client: MistralAsyncClient = None #: :meta private:
|
|
|
|
|
mistral_api_key: Optional[SecretStr] = None
|
|
|
|
|
endpoint: str = DEFAULT_MISTRAL_ENDPOINT
|
|
|
|
|
max_retries: int = 5
|
|
|
|
|
timeout: int = 120
|
|
|
|
@ -224,14 +218,23 @@ class ChatMistralAI(BaseChatModel):
|
|
|
|
|
"Please install it with `pip install mistralai`"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
values["mistral_api_key"] = get_from_dict_or_env(
|
|
|
|
|
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
|
|
|
|
|
values["mistral_api_key"] = convert_to_secret_str(
|
|
|
|
|
get_from_dict_or_env(
|
|
|
|
|
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
values["client"] = MistralClient(
|
|
|
|
|
api_key=values["mistral_api_key"],
|
|
|
|
|
api_key=values["mistral_api_key"].get_secret_value(),
|
|
|
|
|
endpoint=values["endpoint"],
|
|
|
|
|
max_retries=values["max_retries"],
|
|
|
|
|
timeout=values["timeout"],
|
|
|
|
|
)
|
|
|
|
|
values["async_client"] = MistralAsyncClient(
|
|
|
|
|
api_key=values["mistral_api_key"].get_secret_value(),
|
|
|
|
|
endpoint=values["endpoint"],
|
|
|
|
|
max_retries=values["max_retries"],
|
|
|
|
|
timeout=values["timeout"],
|
|
|
|
|
max_concurrent_requests=values["max_concurrent_requests"],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
|
|
|
|