from __future__ import annotations
import logging
import uuid
from operator import itemgetter
from typing import (
    Any,
    AsyncContextManager,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import httpx
from httpx_sse import EventSource, aconnect_sse, connect_sse
from langchain_core._api import beta
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
    make_invalid_tool_call,
    parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool

logger = logging.getLogger(__name__)

def _create_retry_decorator(
    llm: ChatMistralAI,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Returns a tenacity retry decorator, preconfigured to handle exceptions"""
    errors = [httpx.RequestError, httpx.StreamError]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )

def _convert_mistral_chat_message_to_message(
    _message: Dict,
) -> BaseMessage:
    role = _message["role"]
    assert role == "assistant", f"Expected role to be 'assistant', got {role}"
    content = cast(str, _message["content"])

    additional_kwargs: Dict = {}
    tool_calls = []
    invalid_tool_calls = []
    if raw_tool_calls := _message.get("tool_calls"):
        additional_kwargs["tool_calls"] = raw_tool_calls
        for raw_tool_call in raw_tool_calls:
            try:
                parsed: dict = cast(
                    dict, parse_tool_call(raw_tool_call, return_id=True)
                )
                # Mistral returns a single ID per response rather than one per
                # tool call, so fall back to a locally generated UUID when the
                # parsed tool call has no id.
                if not parsed["id"]:
                    tool_call_id = uuid.uuid4().hex[:]
                    tool_calls.append(
                        {
                            **parsed,
                            **{"id": tool_call_id},
                        },
                    )
                else:
                    tool_calls.append(parsed)
            except Exception as e:
                invalid_tool_calls.append(
                    dict(make_invalid_tool_call(raw_tool_call, str(e)))
                )
    return AIMessage(
        content=content,
        additional_kwargs=additional_kwargs,
        tool_calls=tool_calls,
        invalid_tool_calls=invalid_tool_calls,
    )

def _raise_on_error(response: httpx.Response) -> None:
    """Raise an error if the response is an error."""
    if httpx.codes.is_error(response.status_code):
        error_message = response.read().decode("utf-8")
        raise httpx.HTTPStatusError(
            f"Error response {response.status_code} "
            f"while fetching {response.url}: {error_message}",
            request=response.request,
            response=response,
        )


async def _araise_on_error(response: httpx.Response) -> None:
    """Raise an error if the response is an error."""
    if httpx.codes.is_error(response.status_code):
        error_message = (await response.aread()).decode("utf-8")
        raise httpx.HTTPStatusError(
            f"Error response {response.status_code} "
            f"while fetching {response.url}: {error_message}",
            request=response.request,
            response=response,
        )

async def _aiter_sse(
    event_source_mgr: AsyncContextManager[EventSource],
) -> AsyncIterator[Dict]:
    """Iterate over the server-sent events."""
    async with event_source_mgr as event_source:
        # TODO(Team): Remove after this is fixed in httpx dependency
        # https://github.com/florimondmanca/httpx-sse/pull/25/files
        await _araise_on_error(event_source._response)
        async for event in event_source.aiter_sse():
            if event.data == "[DONE]":
                return
            yield event.json()

async def acompletion_with_retry(
    llm: ChatMistralAI,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        if "stream" not in kwargs:
            kwargs["stream"] = False
        stream = kwargs["stream"]
        if stream:
            event_source = aconnect_sse(
                llm.async_client, "POST", "/chat/completions", json=kwargs
            )
            return _aiter_sse(event_source)
        else:
            response = await llm.async_client.post(
                url="/chat/completions", json=kwargs
            )
            await _araise_on_error(response)
            return response.json()

    return await _completion_with_retry(**kwargs)

def _convert_delta_to_message_chunk(
    _delta: Dict, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
    role = _delta.get("role")
    content = _delta.get("content") or ""
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content)
    elif role == "assistant" or default_class == AIMessageChunk:
        additional_kwargs: Dict = {}
        if raw_tool_calls := _delta.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            try:
                tool_call_chunks = []
                for raw_tool_call in raw_tool_calls:
                    if not raw_tool_call.get("index") and not raw_tool_call.get("id"):
                        tool_call_id = uuid.uuid4().hex[:]
                    else:
                        tool_call_id = raw_tool_call.get("id")
                    tool_call_chunks.append(
                        {
                            "name": raw_tool_call["function"].get("name"),
                            "args": raw_tool_call["function"].get("arguments"),
                            "id": tool_call_id,
                            "index": raw_tool_call.get("index"),
                        }
                    )
            except KeyError:
                pass
        else:
            tool_call_chunks = []
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            tool_call_chunks=tool_call_chunks,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content)
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role)
    else:
        return default_class(content=content)


def _convert_message_to_mistral_chat_message(
    message: BaseMessage,
) -> Dict:
    if isinstance(message, ChatMessage):
        return dict(role=message.role, content=message.content)
    elif isinstance(message, HumanMessage):
        return dict(role="user", content=message.content)
    elif isinstance(message, AIMessage):
        if "tool_calls" in message.additional_kwargs:
            tool_calls = []
            for tc in message.additional_kwargs["tool_calls"]:
                chunk = {
                    "function": {
                        "name": tc["function"]["name"],
                        "arguments": tc["function"]["arguments"],
                    }
                }
                if _id := tc.get("id"):
                    chunk["id"] = _id
                tool_calls.append(chunk)
        else:
            tool_calls = []
        return {
            "role": "assistant",
            "content": message.content,
            "tool_calls": tool_calls,
        }
    elif isinstance(message, SystemMessage):
        return dict(role="system", content=message.content)
    elif isinstance(message, ToolMessage):
        return {
            "role": "tool",
            "content": message.content,
            "name": message.name,
        }
    else:
        raise ValueError(f"Got unknown type {message}")

class ChatMistralAI(BaseChatModel):
    """A chat model that uses the MistralAI API.
    """

    client: httpx.Client = Field(default=None)  #: :meta private:
    async_client: httpx.AsyncClient = Field(default=None)  #: :meta private:
    mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    endpoint: str = "https://api.mistral.ai/v1"
    max_retries: int = 5
    timeout: int = 120
    max_concurrent_requests: int = 64
    model: str = "mistral-small"
    temperature: float = 0.7
    max_tokens: Optional[int] = None
    top_p: float = 1
    """Decode using nucleus sampling: consider the smallest set of tokens whose
    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
    random_seed: Optional[int] = None
    safe_mode: bool = False
    streaming: bool = False

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the API."""
        defaults = {
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "random_seed": self.random_seed,
            "safe_prompt": self.safe_mode,
        }
        filtered = {k: v for k, v in defaults.items() if v is not None}
        return filtered

    @property
    def _client_params(self) -> Dict[str, Any]:
        """Get the parameters used for the client."""
        return self._default_params

    def completion_with_retry(
        self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
    ) -> Any:
        """Use tenacity to retry the completion call."""
        # retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

        # @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            if "stream" not in kwargs:
                kwargs["stream"] = False
            stream = kwargs["stream"]
            if stream:

                def iter_sse() -> Iterator[Dict]:
                    with connect_sse(
                        self.client, "POST", "/chat/completions", json=kwargs
                    ) as event_source:
                        # TODO(Team): Remove after this is fixed in httpx dependency
                        # https://github.com/florimondmanca/httpx-sse/pull/25/files
                        _raise_on_error(event_source._response)
                        for event in event_source.iter_sse():
                            if event.data == "[DONE]":
                                return
                            yield event.json()

                return iter_sse()
            else:
                response = self.client.post(url="/chat/completions", json=kwargs)
                _raise_on_error(response)
                return response.json()

        rtn = _completion_with_retry(**kwargs)
        return rtn

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
        combined = {"token_usage": overall_token_usage, "model_name": self.model}
        return combined

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, and top_p."""
        values["mistral_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values, "mistral_api_key", "MISTRAL_API_KEY", default=""
            )
        )
        api_key_str = values["mistral_api_key"].get_secret_value()
        # todo: handle retries
        values["client"] = httpx.Client(
            base_url=values["endpoint"],
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key_str}",
            },
            timeout=values["timeout"],
        )
        # todo: handle retries and max_concurrency
        values["async_client"] = httpx.AsyncClient(
            base_url=values["endpoint"],
            headers={
                "Content-Type": "application/json",
                "Accept": "application/json",
                "Authorization": f"Bearer {api_key_str}",
            },
            timeout=values["timeout"],
        )
        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")
        if values["top_p"] is not None and not 0 <= values["top_p"] <= 1:
            raise ValueError("top_p must be in the range [0.0, 1.0]")
        return values

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def _create_chat_result(self, response: Dict) -> ChatResult:
        generations = []
        for res in response["choices"]:
            finish_reason = res.get("finish_reason")
            gen = ChatGeneration(
                message=_convert_mistral_chat_message_to_message(res["message"]),
                generation_info={"finish_reason": finish_reason},
            )
            generations.append(gen)

        token_usage = response.get("usage", {})
        llm_output = {"token_usage": token_usage, "model": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict], Dict[str, Any]]:
        params = self._client_params
        if stop is not None or "stop" in params:
            if "stop" in params:
                params.pop("stop")
            logger.warning(
                "Parameter `stop` not yet supported (https://docs.mistral.ai/api)"
            )
        message_dicts = [_convert_message_to_mistral_chat_message(m) for m in messages]
        return message_dicts, params

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        for chunk in self.completion_with_retry(
            messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            # make future chunks same type as first chunk
            default_chunk_class = new_chunk.__class__
            gen_chunk = ChatGenerationChunk(message=new_chunk)
            if run_manager:
                run_manager.on_llm_new_token(
                    token=cast(str, new_chunk.content), chunk=gen_chunk
                )
            yield gen_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs, "stream": True}

        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        async for chunk in await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        ):
            if len(chunk["choices"]) == 0:
                continue
            delta = chunk["choices"][0]["delta"]
            new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            # make future chunks same type as first chunk
            default_chunk_class = new_chunk.__class__
            gen_chunk = ChatGenerationChunk(message=new_chunk)
            if run_manager:
                await run_manager.on_llm_new_token(
                    token=cast(str, new_chunk.content), chunk=gen_chunk
                )
            yield gen_chunk

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        should_stream = stream if stream is not None else False
        if should_stream:
            stream_iter = self._astream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        response = await acompletion_with_retry(
            self, messages=message_dicts, run_manager=run_manager, **params
        )
        return self._create_chat_result(response)

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with the OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call.
                Must be the name of the single provided function,
                "auto" to automatically determine which function to call
                (if any), or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
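
        Example:
            A minimal sketch of binding a single tool. The ``GetWeather`` schema
            below is illustrative and not part of this module; the sketch assumes
            ``MISTRAL_API_KEY`` is set in the environment.

            .. code-block:: python

                from langchain_core.pydantic_v1 import BaseModel, Field
                from langchain_mistralai import ChatMistralAI

                class GetWeather(BaseModel):
                    '''Get the current weather in a given location.'''

                    location: str = Field(..., description="City name, e.g. Paris")

                llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
                llm_with_tools = llm.bind_tools([GetWeather])
                llm_with_tools.invoke("What is the weather like in Paris?")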
"""
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)

    @beta()
    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
""" Model wrapper that returns outputs formatted to match the given schema.
Args :
schema : The output schema as a dict or a Pydantic class . If a Pydantic class
then the model output will be an object of that class . If a dict then
the model output will be a dict . With a Pydantic class the returned
attributes will be validated , whereas with a dict they will not be . If
` method ` is " function_calling " and ` schema ` is a dict , then the dict
must match the OpenAI function - calling spec .
include_raw : If False then only the parsed structured output is returned . If
an error occurs during model output parsing it will be raised . If True
then both the raw model response ( a BaseMessage ) and the parsed model
response will be returned . If an error occurs during output parsing it
will be caught and returned as well . The final output is always a dict
with keys " raw " , " parsed " , and " parsing_error " .
Returns :
A Runnable that takes any ChatModel input and returns as output :
If include_raw is True then a dict with keys :
raw : BaseMessage
parsed : Optional [ _DictOrPydantic ]
parsing_error : Optional [ BaseException ]
If include_raw is False then just _DictOrPydantic is returned ,
where _DictOrPydantic depends on the schema :
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class .
If schema is a dict then _DictOrPydantic is a dict .
Example : Function - calling , Pydantic schema ( method = " function_calling " , include_raw = False ) :
. . code - block : : python
from langchain_mistralai import ChatMistralAI
from langchain_core . pydantic_v1 import BaseModel
class AnswerWithJustification ( BaseModel ) :
''' An answer to the user question along with justification for the answer. '''
answer : str
justification : str
llm = ChatMistralAI ( model = " mistral-large-latest " , temperature = 0 )
structured_llm = llm . with_structured_output ( AnswerWithJustification )
structured_llm . invoke ( " What weighs more a pound of bricks or a pound of feathers " )
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
Example : Function - calling , Pydantic schema ( method = " function_calling " , include_raw = True ) :
. . code - block : : python
from langchain_mistralai import ChatMistralAI
from langchain_core . pydantic_v1 import BaseModel
class AnswerWithJustification ( BaseModel ) :
''' An answer to the user question along with justification for the answer. '''
answer : str
justification : str
llm = ChatMistralAI ( model = " mistral-large-latest " , temperature = 0 )
structured_llm = llm . with_structured_output ( AnswerWithJustification , include_raw = True )
structured_llm . invoke ( " What weighs more a pound of bricks or a pound of feathers " )
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }
Example : Function - calling , dict schema ( method = " function_calling " , include_raw = False ) :
. . code - block : : python
from langchain_mistralai import ChatMistralAI
from langchain_core . pydantic_v1 import BaseModel
from langchain_core . utils . function_calling import convert_to_openai_tool
class AnswerWithJustification ( BaseModel ) :
''' An answer to the user question along with justification for the answer. '''
answer : str
justification : str
dict_schema = convert_to_openai_tool ( AnswerWithJustification )
llm = ChatMistralAI ( model = " mistral-large-latest " , temperature = 0 )
structured_llm = llm . with_structured_output ( dict_schema )
structured_llm . invoke ( " What weighs more a pound of bricks or a pound of feathers " )
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
""" # noqa: E501
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
        llm = self.bind_tools([schema], tool_choice="any")
        if is_pydantic_schema:
            output_parser: OutputParserLike = PydanticToolsParser(
                tools=[schema], first_tool_only=True
            )
        else:
            key_name = convert_to_openai_tool(schema)["function"]["name"]
            output_parser = JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )
        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mistralai-chat"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"mistral_api_key": "MISTRAL_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by Langchain."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "chat_models", "mistralai"]