mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
mistralai[minor]: 0.1.0rc0, remove mistral sdk (#19420)
This commit is contained in:
parent
e980c14d6a
commit
53ac1ebbbc
@ -1,10 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
@ -18,6 +18,8 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, aconnect_sse, connect_sse
|
||||
from langchain_core._api import beta
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -54,19 +56,6 @@ from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
from mistralai.async_client import MistralAsyncClient
|
||||
from mistralai.client import MistralClient
|
||||
from mistralai.constants import ENDPOINT as DEFAULT_MISTRAL_ENDPOINT
|
||||
from mistralai.exceptions import (
|
||||
MistralAPIException,
|
||||
MistralConnectionException,
|
||||
MistralException,
|
||||
)
|
||||
from mistralai.models.chat_completion import (
|
||||
ChatCompletionResponse as MistralChatCompletionResponse,
|
||||
)
|
||||
from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
|
||||
from mistralai.models.chat_completion import DeltaMessage as MistralDeltaMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -79,36 +68,34 @@ def _create_retry_decorator(
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Returns a tenacity retry decorator, preconfigured to handle exceptions"""
|
||||
|
||||
errors = [
|
||||
MistralException,
|
||||
MistralAPIException,
|
||||
MistralConnectionException,
|
||||
]
|
||||
errors = [httpx.RequestError, httpx.StreamError]
|
||||
return create_base_retry_decorator(
|
||||
error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
|
||||
)
|
||||
|
||||
|
||||
def _convert_mistral_chat_message_to_message(
|
||||
_message: MistralChatMessage,
|
||||
_message: Dict,
|
||||
) -> BaseMessage:
|
||||
role = _message.role
|
||||
content = cast(Union[str, List], _message.content)
|
||||
if role == "user":
|
||||
return HumanMessage(content=content)
|
||||
elif role == "assistant":
|
||||
additional_kwargs: Dict = {}
|
||||
if hasattr(_message, "tool_calls") and getattr(_message, "tool_calls"):
|
||||
additional_kwargs["tool_calls"] = [
|
||||
tc.model_dump() for tc in getattr(_message, "tool_calls")
|
||||
]
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
elif role == "tool":
|
||||
return ToolMessage(content=content, name=_message.name) # type: ignore[attr-defined]
|
||||
else:
|
||||
return ChatMessage(content=content, role=role)
|
||||
role = _message["role"]
|
||||
assert role == "assistant", f"Expected role to be 'assistant', got {role}"
|
||||
content = cast(str, _message["content"])
|
||||
|
||||
additional_kwargs: Dict = {}
|
||||
if tool_calls := _message.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls]
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
|
||||
|
||||
async def _aiter_sse(
|
||||
event_source_mgr: AsyncContextManager[EventSource],
|
||||
) -> AsyncIterator[Dict]:
|
||||
"""Iterate over the server-sent events."""
|
||||
async with event_source_mgr as event_source:
|
||||
async for event in event_source.aiter_sse():
|
||||
if event.data == "[DONE]":
|
||||
return
|
||||
yield event.json()
|
||||
|
||||
|
||||
async def acompletion_with_retry(
|
||||
@ -121,28 +108,33 @@ async def acompletion_with_retry(
|
||||
|
||||
@retry_decorator
|
||||
async def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
stream = kwargs.pop("stream", False)
|
||||
if "stream" not in kwargs:
|
||||
kwargs["stream"] = False
|
||||
stream = kwargs["stream"]
|
||||
if stream:
|
||||
return llm.async_client.chat_stream(**kwargs)
|
||||
event_source = aconnect_sse(
|
||||
llm.async_client, "POST", "/chat/completions", json=kwargs
|
||||
)
|
||||
|
||||
return _aiter_sse(event_source)
|
||||
else:
|
||||
return await llm.async_client.chat(**kwargs)
|
||||
response = await llm.async_client.post(url="/chat/completions", json=kwargs)
|
||||
return response.json()
|
||||
|
||||
return await _completion_with_retry(**kwargs)
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_delta: MistralDeltaMessage, default_class: Type[BaseMessageChunk]
|
||||
_delta: Dict, default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = getattr(_delta, "role")
|
||||
content = getattr(_delta, "content", "")
|
||||
role = _delta.get("role")
|
||||
content = _delta.get("content", "")
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
additional_kwargs: Dict = {}
|
||||
if hasattr(_delta, "tool_calls") and getattr(_delta, "tool_calls"):
|
||||
additional_kwargs["tool_calls"] = [
|
||||
tc.model_dump() for tc in getattr(_delta, "tool_calls")
|
||||
]
|
||||
if tool_calls := _delta.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls]
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
@ -154,44 +146,48 @@ def _convert_delta_to_message_chunk(
|
||||
|
||||
def _convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage,
|
||||
) -> MistralChatMessage:
|
||||
) -> Dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
mistral_message = MistralChatMessage(role=message.role, content=message.content)
|
||||
return dict(role=message.role, content=message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
mistral_message = MistralChatMessage(role="user", content=message.content)
|
||||
return dict(role="user", content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
if "tool_calls" in message.additional_kwargs:
|
||||
from mistralai.models.chat_completion import ( # type: ignore[attr-defined]
|
||||
ToolCall as MistralToolCall,
|
||||
)
|
||||
|
||||
tool_calls = [
|
||||
MistralToolCall.model_validate(tc)
|
||||
{
|
||||
"function": {
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"],
|
||||
}
|
||||
}
|
||||
for tc in message.additional_kwargs["tool_calls"]
|
||||
]
|
||||
else:
|
||||
tool_calls = None
|
||||
mistral_message = MistralChatMessage(
|
||||
role="assistant", content=message.content, tool_calls=tool_calls
|
||||
)
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
elif isinstance(message, SystemMessage):
|
||||
mistral_message = MistralChatMessage(role="system", content=message.content)
|
||||
return dict(role="system", content=message.content)
|
||||
elif isinstance(message, ToolMessage):
|
||||
mistral_message = MistralChatMessage(
|
||||
role="tool", content=message.content, name=message.name
|
||||
)
|
||||
return {
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return mistral_message
|
||||
|
||||
|
||||
class ChatMistralAI(BaseChatModel):
|
||||
"""A chat model that uses the MistralAI API."""
|
||||
|
||||
client: MistralClient = Field(default=None) #: :meta private:
|
||||
async_client: MistralAsyncClient = Field(default=None) #: :meta private:
|
||||
client: httpx.Client = Field(default=None) #: :meta private:
|
||||
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
|
||||
mistral_api_key: Optional[SecretStr] = None
|
||||
endpoint: str = DEFAULT_MISTRAL_ENDPOINT
|
||||
endpoint: str = "https://api.mistral.ai/v1"
|
||||
max_retries: int = 5
|
||||
timeout: int = 120
|
||||
max_concurrent_requests: int = 64
|
||||
@ -204,6 +200,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
|
||||
random_seed: Optional[int] = None
|
||||
safe_mode: bool = False
|
||||
streaming: bool = False
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
@ -214,7 +211,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"random_seed": self.random_seed,
|
||||
"safe_mode": self.safe_mode,
|
||||
"safe_prompt": self.safe_mode,
|
||||
}
|
||||
filtered = {k: v for k, v in defaults.items() if v is not None}
|
||||
return filtered
|
||||
@ -228,45 +225,60 @@ class ChatMistralAI(BaseChatModel):
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
# retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
|
||||
@retry_decorator
|
||||
# @retry_decorator
|
||||
def _completion_with_retry(**kwargs: Any) -> Any:
|
||||
stream = kwargs.pop("stream", False)
|
||||
if "stream" not in kwargs:
|
||||
kwargs["stream"] = False
|
||||
stream = kwargs["stream"]
|
||||
if stream:
|
||||
return self.client.chat_stream(**kwargs)
|
||||
else:
|
||||
return self.client.chat(**kwargs)
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
def iter_sse() -> Iterator[Dict]:
|
||||
with connect_sse(
|
||||
self.client, "POST", "/chat/completions", json=kwargs
|
||||
) as event_source:
|
||||
for event in event_source.iter_sse():
|
||||
if event.data == "[DONE]":
|
||||
return
|
||||
yield event.json()
|
||||
|
||||
return iter_sse()
|
||||
else:
|
||||
return self.client.post(url="/chat/completions", json=kwargs).json()
|
||||
|
||||
rtn = _completion_with_retry(**kwargs)
|
||||
return rtn
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate api key, python package exists, temperature, and top_p."""
|
||||
mistralai_spec = importlib.util.find_spec("mistralai")
|
||||
if mistralai_spec is None:
|
||||
raise MistralException(
|
||||
"Could not find mistralai python package. "
|
||||
"Please install it with `pip install mistralai`"
|
||||
)
|
||||
|
||||
values["mistral_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
|
||||
)
|
||||
)
|
||||
values["client"] = MistralClient(
|
||||
api_key=values["mistral_api_key"].get_secret_value(),
|
||||
endpoint=values["endpoint"],
|
||||
max_retries=values["max_retries"],
|
||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||
# todo: handle retries
|
||||
values["client"] = httpx.Client(
|
||||
base_url=values["endpoint"],
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
)
|
||||
values["async_client"] = MistralAsyncClient(
|
||||
api_key=values["mistral_api_key"].get_secret_value(),
|
||||
endpoint=values["endpoint"],
|
||||
max_retries=values["max_retries"],
|
||||
# todo: handle retries and max_concurrency
|
||||
values["async_client"] = httpx.AsyncClient(
|
||||
base_url=values["endpoint"],
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
max_concurrent_requests=values["max_concurrent_requests"],
|
||||
)
|
||||
|
||||
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
|
||||
@ -285,7 +297,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else False
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
if should_stream:
|
||||
stream_iter = self._stream(
|
||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
@ -299,27 +311,23 @@ class ChatMistralAI(BaseChatModel):
|
||||
)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _create_chat_result(
|
||||
self, response: MistralChatCompletionResponse
|
||||
) -> ChatResult:
|
||||
def _create_chat_result(self, response: Dict) -> ChatResult:
|
||||
generations = []
|
||||
for res in response.choices:
|
||||
finish_reason = getattr(res, "finish_reason")
|
||||
if finish_reason:
|
||||
finish_reason = finish_reason.value
|
||||
for res in response["choices"]:
|
||||
finish_reason = res.get("finish_reason")
|
||||
gen = ChatGeneration(
|
||||
message=_convert_mistral_chat_message_to_message(res.message),
|
||||
message=_convert_mistral_chat_message_to_message(res["message"]),
|
||||
generation_info={"finish_reason": finish_reason},
|
||||
)
|
||||
generations.append(gen)
|
||||
token_usage = getattr(response, "usage")
|
||||
token_usage = vars(token_usage) if token_usage else {}
|
||||
token_usage = response.get("usage", {})
|
||||
|
||||
llm_output = {"token_usage": token_usage, "model": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[List[MistralChatMessage], Dict[str, Any]]:
|
||||
) -> Tuple[List[Dict], Dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None or "stop" in params:
|
||||
if "stop" in params:
|
||||
@ -340,20 +348,24 @@ class ChatMistralAI(BaseChatModel):
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class = AIMessageChunk
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
for chunk in self.completion_with_retry(
|
||||
messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.choices) == 0:
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta.content:
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if not delta["content"]:
|
||||
continue
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
|
||||
yield ChatGenerationChunk(message=chunk)
|
||||
run_manager.on_llm_new_token(
|
||||
token=cast(str, new_chunk.content), chunk=gen_chunk
|
||||
)
|
||||
yield gen_chunk
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
@ -365,20 +377,24 @@ class ChatMistralAI(BaseChatModel):
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
params = {**params, **kwargs, "stream": True}
|
||||
|
||||
default_chunk_class = AIMessageChunk
|
||||
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
|
||||
async for chunk in await acompletion_with_retry(
|
||||
self, messages=message_dicts, run_manager=run_manager, **params
|
||||
):
|
||||
if len(chunk.choices) == 0:
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta.content:
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if not delta["content"]:
|
||||
continue
|
||||
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
default_chunk_class = chunk.__class__
|
||||
new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
|
||||
# make future chunks same type as first chunk
|
||||
default_chunk_class = new_chunk.__class__
|
||||
gen_chunk = ChatGenerationChunk(message=new_chunk)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(token=chunk.content, chunk=chunk)
|
||||
yield ChatGenerationChunk(message=chunk)
|
||||
await run_manager.on_llm_new_token(
|
||||
token=cast(str, new_chunk.content), chunk=gen_chunk
|
||||
)
|
||||
yield gen_chunk
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
|
@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import httpx
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
@ -11,12 +12,6 @@ from langchain_core.pydantic_v1 import (
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from mistralai.async_client import MistralAsyncClient
|
||||
from mistralai.client import MistralClient
|
||||
from mistralai.constants import (
|
||||
ENDPOINT as DEFAULT_MISTRAL_ENDPOINT,
|
||||
)
|
||||
from mistralai.exceptions import MistralException
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -40,10 +35,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
"""
|
||||
|
||||
client: MistralClient = Field(default=None) #: :meta private:
|
||||
async_client: MistralAsyncClient = Field(default=None) #: :meta private:
|
||||
client: httpx.Client = Field(default=None) #: :meta private:
|
||||
async_client: httpx.AsyncClient = Field(default=None) #: :meta private:
|
||||
mistral_api_key: Optional[SecretStr] = None
|
||||
endpoint: str = DEFAULT_MISTRAL_ENDPOINT
|
||||
endpoint: str = "https://api.mistral.ai/v1/"
|
||||
max_retries: int = 5
|
||||
timeout: int = 120
|
||||
max_concurrent_requests: int = 64
|
||||
@ -64,18 +59,26 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
|
||||
)
|
||||
)
|
||||
values["client"] = MistralClient(
|
||||
api_key=values["mistral_api_key"].get_secret_value(),
|
||||
endpoint=values["endpoint"],
|
||||
max_retries=values["max_retries"],
|
||||
api_key_str = values["mistral_api_key"].get_secret_value()
|
||||
# todo: handle retries
|
||||
values["client"] = httpx.Client(
|
||||
base_url=values["endpoint"],
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
)
|
||||
values["async_client"] = MistralAsyncClient(
|
||||
api_key=values["mistral_api_key"].get_secret_value(),
|
||||
endpoint=values["endpoint"],
|
||||
max_retries=values["max_retries"],
|
||||
# todo: handle retries and max_concurrency
|
||||
values["async_client"] = httpx.AsyncClient(
|
||||
base_url=values["endpoint"],
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key_str}",
|
||||
},
|
||||
timeout=values["timeout"],
|
||||
max_concurrent_requests=values["max_concurrent_requests"],
|
||||
)
|
||||
if values["tokenizer"] is None:
|
||||
values["tokenizer"] = Tokenizer.from_pretrained(
|
||||
@ -115,18 +118,21 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
try:
|
||||
batch_responses = (
|
||||
self.client.embeddings(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
self.client.post(
|
||||
url="/embeddings",
|
||||
json=dict(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
),
|
||||
)
|
||||
for batch in self._get_batches(texts)
|
||||
)
|
||||
return [
|
||||
list(map(float, embedding_obj.embedding))
|
||||
list(map(float, embedding_obj["embedding"]))
|
||||
for response in batch_responses
|
||||
for embedding_obj in response.data
|
||||
for embedding_obj in response.json()["data"]
|
||||
]
|
||||
except MistralException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
raise
|
||||
|
||||
@ -142,19 +148,22 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
|
||||
try:
|
||||
batch_responses = await asyncio.gather(
|
||||
*[
|
||||
self.async_client.embeddings(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
self.async_client.post(
|
||||
url="/embeddings",
|
||||
json=dict(
|
||||
model=self.model,
|
||||
input=batch,
|
||||
),
|
||||
)
|
||||
for batch in self._get_batches(texts)
|
||||
]
|
||||
)
|
||||
return [
|
||||
list(map(float, embedding_obj.embedding))
|
||||
list(map(float, embedding_obj["embedding"]))
|
||||
for response in batch_responses
|
||||
for embedding_obj in response.data
|
||||
for embedding_obj in response.json()["data"]
|
||||
]
|
||||
except MistralException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred with MistralAI: {e}")
|
||||
raise
|
||||
|
||||
|
297
libs/partners/mistralai/poetry.lock
generated
297
libs/partners/mistralai/poetry.lock
generated
@ -206,13 +206,13 @@ typing = ["typing-extensions (>=4.8)"]
|
||||
|
||||
[[package]]
|
||||
name = "fsspec"
|
||||
version = "2024.2.0"
|
||||
version = "2024.3.1"
|
||||
description = "File-system specification"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"},
|
||||
{file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"},
|
||||
{file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
|
||||
{file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@ -273,13 +273,13 @@ trio = ["trio (>=0.22.0,<0.25.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx"
|
||||
version = "0.25.2"
|
||||
version = "0.27.0"
|
||||
description = "The next generation HTTP client."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"},
|
||||
{file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"},
|
||||
{file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"},
|
||||
{file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -295,15 +295,26 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.0"
|
||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
|
||||
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.20.3"
|
||||
version = "0.21.4"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.20.3-py3-none-any.whl", hash = "sha256:d988ae4f00d3e307b0c80c6a05ca6dbb7edba8bba3079f74cda7d9c2e562a7b6"},
|
||||
{file = "huggingface_hub-0.20.3.tar.gz", hash = "sha256:94e7f8e074475fbc67d6a71957b678e1b4a74ff1b64a644fd6cbb83da962d05d"},
|
||||
{file = "huggingface_hub-0.21.4-py3-none-any.whl", hash = "sha256:df37c2c37fc6c82163cdd8a67ede261687d80d1e262526d6c0ce73b6b3630a7b"},
|
||||
{file = "huggingface_hub-0.21.4.tar.gz", hash = "sha256:e1f4968c93726565a80edf6dc309763c7b546d0cfe79aa221206034d50155531"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -320,11 +331,12 @@ all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi",
|
||||
cli = ["InquirerPy (==0.3.4)"]
|
||||
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "mypy (==1.5.1)", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.1.3)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
|
||||
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
|
||||
hf-transfer = ["hf-transfer (>=0.1.4)"]
|
||||
inference = ["aiohttp", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)"]
|
||||
quality = ["mypy (==1.5.1)", "ruff (>=0.1.3)"]
|
||||
tensorflow = ["graphviz", "pydot", "tensorflow"]
|
||||
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "numpy", "pydantic (>1.1,<2.0)", "pydantic (>1.1,<3.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
|
||||
torch = ["torch"]
|
||||
torch = ["safetensors", "torch"]
|
||||
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
|
||||
|
||||
[[package]]
|
||||
@ -376,7 +388,7 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.1.27"
|
||||
version = "0.1.33"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -402,13 +414,13 @@ url = "../../core"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.8"
|
||||
version = "0.1.31"
|
||||
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
python-versions = "<4.0,>=3.8.1"
|
||||
files = [
|
||||
{file = "langsmith-0.1.8-py3-none-any.whl", hash = "sha256:f4320fd80ec9d311a648e7d4c44e0814e6e5454772c5026f40db0307bc07e287"},
|
||||
{file = "langsmith-0.1.8.tar.gz", hash = "sha256:ab5f1cdfb7d418109ea506d41928fb8708547db2f6c7f7da7cfe997f3c55767b"},
|
||||
{file = "langsmith-0.1.31-py3-none-any.whl", hash = "sha256:5211a9dc00831db307eb843485a97096484b697b5d2cd1efaac34228e97ca087"},
|
||||
{file = "langsmith-0.1.31.tar.gz", hash = "sha256:efd54ccd44be7fda911bfdc0ead340473df2fdd07345c7252901834d0c4aa37e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -416,40 +428,6 @@ orjson = ">=3.9.14,<4.0.0"
|
||||
pydantic = ">=1,<3"
|
||||
requests = ">=2,<3"
|
||||
|
||||
[[package]]
|
||||
name = "mistralai"
|
||||
version = "0.0.12"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.8,<4.0"
|
||||
files = [
|
||||
{file = "mistralai-0.0.12-py3-none-any.whl", hash = "sha256:d489d1f0a31bf0edbe15c6d12f68b943148d2a725a088be0d8a5d4c888f8436c"},
|
||||
{file = "mistralai-0.0.12.tar.gz", hash = "sha256:fe652836146a15bdce7691a95803a32c53c641c5400093447ffa93bf2ed296b2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.25.2,<0.26.0"
|
||||
orjson = ">=3.9.10,<4.0.0"
|
||||
pydantic = ">=2.5.2,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "mistralai"
|
||||
version = "0.1.2"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.9,<4.0"
|
||||
files = [
|
||||
{file = "mistralai-0.1.2-py3-none-any.whl", hash = "sha256:5e74e5ef0c0f15058892d73b00c659e06e9882c00838a1ad9862d93c77336847"},
|
||||
{file = "mistralai-0.1.2.tar.gz", hash = "sha256:eb915fd15075f71bdbfce9cb476bb647322b1ce1e93b19ab0047728067466397"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.25.2,<0.26.0"
|
||||
orjson = ">=3.9.10,<4.0.0"
|
||||
pandas = ">=2.2.0,<3.0.0"
|
||||
pyarrow = ">=15.0.0,<16.0.0"
|
||||
pydantic = ">=2.5.2,<3.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "mypy"
|
||||
version = "0.991"
|
||||
@ -511,51 +489,6 @@ files = [
|
||||
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "numpy"
|
||||
version = "1.26.4"
|
||||
description = "Fundamental package for array computing in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07"},
|
||||
{file = "numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20"},
|
||||
{file = "numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110"},
|
||||
{file = "numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6"},
|
||||
{file = "numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c"},
|
||||
{file = "numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0"},
|
||||
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.9.15"
|
||||
@ -626,79 +559,6 @@ files = [
|
||||
{file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pandas"
|
||||
version = "2.2.1"
|
||||
description = "Powerful data structures for data analysis, time series, and statistics"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pandas-2.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8df8612be9cd1c7797c93e1c5df861b2ddda0b48b08f2c3eaa0702cf88fb5f88"},
|
||||
{file = "pandas-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0f573ab277252ed9aaf38240f3b54cfc90fff8e5cab70411ee1d03f5d51f3944"},
|
||||
{file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f02a3a6c83df4026e55b63c1f06476c9aa3ed6af3d89b4f04ea656ccdaaaa359"},
|
||||
{file = "pandas-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c38ce92cb22a4bea4e3929429aa1067a454dcc9c335799af93ba9be21b6beb51"},
|
||||
{file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c2ce852e1cf2509a69e98358e8458775f89599566ac3775e70419b98615f4b06"},
|
||||
{file = "pandas-2.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:53680dc9b2519cbf609c62db3ed7c0b499077c7fefda564e330286e619ff0dd9"},
|
||||
{file = "pandas-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:94e714a1cca63e4f5939cdce5f29ba8d415d85166be3441165edd427dc9f6bc0"},
|
||||
{file = "pandas-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f821213d48f4ab353d20ebc24e4faf94ba40d76680642fb7ce2ea31a3ad94f9b"},
|
||||
{file = "pandas-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c70e00c2d894cb230e5c15e4b1e1e6b2b478e09cf27cc593a11ef955b9ecc81a"},
|
||||
{file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e97fbb5387c69209f134893abc788a6486dbf2f9e511070ca05eed4b930b1b02"},
|
||||
{file = "pandas-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:101d0eb9c5361aa0146f500773395a03839a5e6ecde4d4b6ced88b7e5a1a6403"},
|
||||
{file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:7d2ed41c319c9fb4fd454fe25372028dfa417aacb9790f68171b2e3f06eae8cd"},
|
||||
{file = "pandas-2.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:af5d3c00557d657c8773ef9ee702c61dd13b9d7426794c9dfeb1dc4a0bf0ebc7"},
|
||||
{file = "pandas-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:06cf591dbaefb6da9de8472535b185cba556d0ce2e6ed28e21d919704fef1a9e"},
|
||||
{file = "pandas-2.2.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:88ecb5c01bb9ca927ebc4098136038519aa5d66b44671861ffab754cae75102c"},
|
||||
{file = "pandas-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04f6ec3baec203c13e3f8b139fb0f9f86cd8c0b94603ae3ae8ce9a422e9f5bee"},
|
||||
{file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a935a90a76c44fe170d01e90a3594beef9e9a6220021acfb26053d01426f7dc2"},
|
||||
{file = "pandas-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c391f594aae2fd9f679d419e9a4d5ba4bce5bb13f6a989195656e7dc4b95c8f0"},
|
||||
{file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9d1265545f579edf3f8f0cb6f89f234f5e44ba725a34d86535b1a1d38decbccc"},
|
||||
{file = "pandas-2.2.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:11940e9e3056576ac3244baef2fedade891977bcc1cb7e5cc8f8cc7d603edc89"},
|
||||
{file = "pandas-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:4acf681325ee1c7f950d058b05a820441075b0dd9a2adf5c4835b9bc056bf4fb"},
|
||||
{file = "pandas-2.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9bd8a40f47080825af4317d0340c656744f2bfdb6819f818e6ba3cd24c0e1397"},
|
||||
{file = "pandas-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df0c37ebd19e11d089ceba66eba59a168242fc6b7155cba4ffffa6eccdfb8f16"},
|
||||
{file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:739cc70eaf17d57608639e74d63387b0d8594ce02f69e7a0b046f117974b3019"},
|
||||
{file = "pandas-2.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f9d3558d263073ed95e46f4650becff0c5e1ffe0fc3a015de3c79283dfbdb3df"},
|
||||
{file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:4aa1d8707812a658debf03824016bf5ea0d516afdea29b7dc14cf687bc4d4ec6"},
|
||||
{file = "pandas-2.2.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:76f27a809cda87e07f192f001d11adc2b930e93a2b0c4a236fde5429527423be"},
|
||||
{file = "pandas-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:1ba21b1d5c0e43416218db63037dbe1a01fc101dc6e6024bcad08123e48004ab"},
|
||||
{file = "pandas-2.2.1.tar.gz", hash = "sha256:0ab90f87093c13f3e8fa45b48ba9f39181046e8f3317d3aadb2fffbb1b978572"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
|
||||
{version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
|
||||
]
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
tzdata = ">=2022.7"
|
||||
|
||||
[package.extras]
|
||||
all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"]
|
||||
aws = ["s3fs (>=2022.11.0)"]
|
||||
clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"]
|
||||
compression = ["zstandard (>=0.19.0)"]
|
||||
computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"]
|
||||
consortium-standard = ["dataframe-api-compat (>=0.1.7)"]
|
||||
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"]
|
||||
feather = ["pyarrow (>=10.0.1)"]
|
||||
fss = ["fsspec (>=2022.11.0)"]
|
||||
gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"]
|
||||
hdf5 = ["tables (>=3.8.0)"]
|
||||
html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"]
|
||||
mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"]
|
||||
output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"]
|
||||
parquet = ["pyarrow (>=10.0.1)"]
|
||||
performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"]
|
||||
plot = ["matplotlib (>=3.6.3)"]
|
||||
postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"]
|
||||
pyarrow = ["pyarrow (>=10.0.1)"]
|
||||
spss = ["pyreadstat (>=1.2.0)"]
|
||||
sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"]
|
||||
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
|
||||
xml = ["lxml (>=4.9.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.4.0"
|
||||
@ -714,63 +574,15 @@ files = [
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "15.0.0"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyarrow-15.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:0a524532fd6dd482edaa563b686d754c70417c2f72742a8c990b322d4c03a15d"},
|
||||
{file = "pyarrow-15.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a6bdb314affa9c2e0d5dddf3d9cbb9ef4a8dddaa68669975287d47ece67642"},
|
||||
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:66958fd1771a4d4b754cd385835e66a3ef6b12611e001d4e5edfcef5f30391e2"},
|
||||
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f500956a49aadd907eaa21d4fff75f73954605eaa41f61cb94fb008cf2e00c6"},
|
||||
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6f87d9c4f09e049c2cade559643424da84c43a35068f2a1c4653dc5b1408a929"},
|
||||
{file = "pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:85239b9f93278e130d86c0e6bb455dcb66fc3fd891398b9d45ace8799a871a1e"},
|
||||
{file = "pyarrow-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:5b8d43e31ca16aa6e12402fcb1e14352d0d809de70edd185c7650fe80e0769e3"},
|
||||
{file = "pyarrow-15.0.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:fa7cd198280dbd0c988df525e50e35b5d16873e2cdae2aaaa6363cdb64e3eec5"},
|
||||
{file = "pyarrow-15.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8780b1a29d3c8b21ba6b191305a2a607de2e30dab399776ff0aa09131e266340"},
|
||||
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fe0ec198ccc680f6c92723fadcb97b74f07c45ff3fdec9dd765deb04955ccf19"},
|
||||
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:036a7209c235588c2f07477fe75c07e6caced9b7b61bb897c8d4e52c4b5f9555"},
|
||||
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2bd8a0e5296797faf9a3294e9fa2dc67aa7f10ae2207920dbebb785c77e9dbe5"},
|
||||
{file = "pyarrow-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e8ebed6053dbe76883a822d4e8da36860f479d55a762bd9e70d8494aed87113e"},
|
||||
{file = "pyarrow-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:17d53a9d1b2b5bd7d5e4cd84d018e2a45bc9baaa68f7e6e3ebed45649900ba99"},
|
||||
{file = "pyarrow-15.0.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:9950a9c9df24090d3d558b43b97753b8f5867fb8e521f29876aa021c52fda351"},
|
||||
{file = "pyarrow-15.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:003d680b5e422d0204e7287bb3fa775b332b3fce2996aa69e9adea23f5c8f970"},
|
||||
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f75fce89dad10c95f4bf590b765e3ae98bcc5ba9f6ce75adb828a334e26a3d40"},
|
||||
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ca9cb0039923bec49b4fe23803807e4ef39576a2bec59c32b11296464623dc2"},
|
||||
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:9ed5a78ed29d171d0acc26a305a4b7f83c122d54ff5270810ac23c75813585e4"},
|
||||
{file = "pyarrow-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:6eda9e117f0402dfcd3cd6ec9bfee89ac5071c48fc83a84f3075b60efa96747f"},
|
||||
{file = "pyarrow-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a3a6180c0e8f2727e6f1b1c87c72d3254cac909e609f35f22532e4115461177"},
|
||||
{file = "pyarrow-15.0.0-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:19a8918045993349b207de72d4576af0191beef03ea655d8bdb13762f0cd6eac"},
|
||||
{file = "pyarrow-15.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0ec076b32bacb6666e8813a22e6e5a7ef1314c8069d4ff345efa6246bc38593"},
|
||||
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5db1769e5d0a77eb92344c7382d6543bea1164cca3704f84aa44e26c67e320fb"},
|
||||
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2617e3bf9df2a00020dd1c1c6dce5cc343d979efe10bc401c0632b0eef6ef5b"},
|
||||
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:d31c1d45060180131caf10f0f698e3a782db333a422038bf7fe01dace18b3a31"},
|
||||
{file = "pyarrow-15.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c8c287d1d479de8269398b34282e206844abb3208224dbdd7166d580804674b7"},
|
||||
{file = "pyarrow-15.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:07eb7f07dc9ecbb8dace0f58f009d3a29ee58682fcdc91337dfeb51ea618a75b"},
|
||||
{file = "pyarrow-15.0.0-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:47af7036f64fce990bb8a5948c04722e4e3ea3e13b1007ef52dfe0aa8f23cf7f"},
|
||||
{file = "pyarrow-15.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:93768ccfff85cf044c418bfeeafce9a8bb0cee091bd8fd19011aff91e58de540"},
|
||||
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6ee87fd6892700960d90abb7b17a72a5abb3b64ee0fe8db6c782bcc2d0dc0b4"},
|
||||
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:001fca027738c5f6be0b7a3159cc7ba16a5c52486db18160909a0831b063c4e4"},
|
||||
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:d1c48648f64aec09accf44140dccb92f4f94394b8d79976c426a5b79b11d4fa7"},
|
||||
{file = "pyarrow-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:972a0141be402bb18e3201448c8ae62958c9c7923dfaa3b3d4530c835ac81aed"},
|
||||
{file = "pyarrow-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:f01fc5cf49081426429127aa2d427d9d98e1cb94a32cb961d583a70b7c4504e6"},
|
||||
{file = "pyarrow-15.0.0.tar.gz", hash = "sha256:876858f549d540898f927eba4ef77cd549ad8d24baa3207cf1b72e5788b50e83"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.16.6,<2"
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.6.2"
|
||||
version = "2.6.4"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pydantic-2.6.2-py3-none-any.whl", hash = "sha256:37a5432e54b12fecaa1049c5195f3d860a10e01bdfd24f1840ef14bd0d3aeab3"},
|
||||
{file = "pydantic-2.6.2.tar.gz", hash = "sha256:a09be1c3d28f3abe37f8a78af58284b236a92ce520105ddc91a6d29ea1176ba7"},
|
||||
{file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"},
|
||||
{file = "pydantic-2.6.4.tar.gz", hash = "sha256:b1704e0847db01817624a6b86766967f552dd9dbf3afba4004409f908dcc84e6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@ -912,31 +724,6 @@ pytest = ">=7.0.0"
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
description = "Extensions to the standard Python datetime module"
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
||||
files = [
|
||||
{file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"},
|
||||
{file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2024.1"
|
||||
description = "World timezone definitions, modern and historical"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pytz-2024.1-py2.py3-none-any.whl", hash = "sha256:328171f4e3623139da4983451950b28e95ac706e13f3f2630a879749e7a8b319"},
|
||||
{file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.1"
|
||||
@ -1044,17 +831,6 @@ files = [
|
||||
{file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "six"
|
||||
version = "1.16.0"
|
||||
description = "Python 2 and 3 compatibility utilities"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
files = [
|
||||
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
|
||||
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sniffio"
|
||||
version = "1.3.1"
|
||||
@ -1249,17 +1025,6 @@ files = [
|
||||
{file = "typing_extensions-4.10.0.tar.gz", hash = "sha256:b0abd7c89e8fb96f98db18d86106ff1d90ab692004eb746cf6eda2682f91b3cb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tzdata"
|
||||
version = "2024.1"
|
||||
description = "Provider of IANA time zone data"
|
||||
optional = false
|
||||
python-versions = ">=2"
|
||||
files = [
|
||||
{file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
|
||||
{file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.2.1"
|
||||
@ -1280,4 +1045,4 @@ zstd = ["zstandard (>=0.18.0)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "ccb95664a734631dde949975506ab160f65cdd222b28bf4f702fb4b11644f418"
|
||||
content-hash = "3d4fde33e55ded42474f7f42fbe34ce877f1deccaffdeed17d4ea26c47d07842"
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-mistralai"
|
||||
version = "0.0.5"
|
||||
version = "0.1.0rc0"
|
||||
description = "An integration package connecting Mistral and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
@ -13,8 +13,9 @@ license = "MIT"
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.27"
|
||||
mistralai = [{version = "^0.1", python = "^3.9"}, {version = ">=0.0.11,<0.2", python="3.8"}]
|
||||
tokenizers = "^0.15.1"
|
||||
httpx = ">=0.25.2,<1"
|
||||
httpx-sse = ">=0.3.1,<1"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
@ -24,17 +25,17 @@ pytest = "^7.3.0"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test MistralAI Chat API wrapper."""
|
||||
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Generator
|
||||
from typing import Any, AsyncGenerator, Dict, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -13,16 +14,6 @@ from langchain_core.messages import (
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
|
||||
from mistralai.models.chat_completion import ( # type: ignore[import]
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
DeltaMessage,
|
||||
)
|
||||
from mistralai.models.chat_completion import (
|
||||
ChatMessage as MistralChatMessage,
|
||||
)
|
||||
|
||||
from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||
ChatMistralAI,
|
||||
_convert_message_to_mistral_chat_message,
|
||||
@ -31,13 +22,11 @@ from langchain_mistralai.chat_models import ( # type: ignore[import]
|
||||
os.environ["MISTRAL_API_KEY"] = "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("mistralai")
|
||||
def test_mistralai_model_param() -> None:
|
||||
llm = ChatMistralAI(model="foo")
|
||||
assert llm.model == "foo"
|
||||
|
||||
|
||||
@pytest.mark.requires("mistralai")
|
||||
def test_mistralai_initialization() -> None:
|
||||
"""Test ChatMistralAI initialization."""
|
||||
# Verify that ChatMistralAI can be initialized using a secret key provided
|
||||
@ -50,37 +39,37 @@ def test_mistralai_initialization() -> None:
|
||||
[
|
||||
(
|
||||
SystemMessage(content="Hello"),
|
||||
MistralChatMessage(role="system", content="Hello"),
|
||||
dict(role="system", content="Hello"),
|
||||
),
|
||||
(
|
||||
HumanMessage(content="Hello"),
|
||||
MistralChatMessage(role="user", content="Hello"),
|
||||
dict(role="user", content="Hello"),
|
||||
),
|
||||
(
|
||||
AIMessage(content="Hello"),
|
||||
MistralChatMessage(role="assistant", content="Hello"),
|
||||
dict(role="assistant", content="Hello", tool_calls=None),
|
||||
),
|
||||
(
|
||||
ChatMessage(role="assistant", content="Hello"),
|
||||
MistralChatMessage(role="assistant", content="Hello"),
|
||||
dict(role="assistant", content="Hello"),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_convert_message_to_mistral_chat_message(
|
||||
message: BaseMessage, expected: MistralChatMessage
|
||||
message: BaseMessage, expected: Dict
|
||||
) -> None:
|
||||
result = _convert_message_to_mistral_chat_message(message)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResponse:
|
||||
return ChatCompletionStreamResponse(
|
||||
def _make_completion_response_from_token(token: str) -> Dict:
|
||||
return dict(
|
||||
id="abc123",
|
||||
model="fake_model",
|
||||
choices=[
|
||||
ChatCompletionResponseStreamChoice(
|
||||
dict(
|
||||
index=0,
|
||||
delta=DeltaMessage(content=token),
|
||||
delta=dict(content=token),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
@ -88,13 +77,19 @@ def _make_completion_response_from_token(token: str) -> ChatCompletionStreamResp
|
||||
|
||||
|
||||
def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
def it() -> Generator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
|
||||
return it()
|
||||
|
||||
|
||||
async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
async def it() -> AsyncGenerator:
|
||||
for token in ["Hello", " how", " can", " I", " help", "?"]:
|
||||
yield _make_completion_response_from_token(token)
|
||||
|
||||
return it()
|
||||
|
||||
|
||||
class MyCustomHandler(BaseCallbackHandler):
|
||||
@ -104,7 +99,10 @@ class MyCustomHandler(BaseCallbackHandler):
|
||||
self.last_token = token
|
||||
|
||||
|
||||
@patch("mistralai.client.MistralClient.chat_stream", new=mock_chat_stream)
|
||||
@patch(
|
||||
"langchain_mistralai.chat_models.ChatMistralAI.completion_with_retry",
|
||||
new=mock_chat_stream,
|
||||
)
|
||||
def test_stream_with_callback() -> None:
|
||||
callback = MyCustomHandler()
|
||||
chat = ChatMistralAI(callbacks=[callback])
|
||||
@ -112,7 +110,7 @@ def test_stream_with_callback() -> None:
|
||||
assert callback.last_token == token.content
|
||||
|
||||
|
||||
@patch("mistralai.async_client.MistralAsyncClient.chat_stream", new=mock_chat_astream)
|
||||
@patch("langchain_mistralai.chat_models.acompletion_with_retry", new=mock_chat_astream)
|
||||
async def test_astream_with_callback() -> None:
|
||||
callback = MyCustomHandler()
|
||||
chat = ChatMistralAI(callbacks=[callback])
|
||||
|
Loading…
Reference in New Issue
Block a user