Add `type` to Agent actions (#11682)

Add `type` to agent actions.
pull/11400/head
Eugene Yurtsev 10 months ago committed by GitHub
parent c14a8df2ee
commit 17b5090c18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Literal, Sequence, Union
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
@ -21,10 +21,12 @@ class AgentAction(Serializable):
thoughts. This is useful when (tool, tool_input) does not contain
full information about the LLM prediction (for example, any `thought`
before the tool/tool_input)."""
type: Literal["AgentAction"] = "AgentAction"
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
@classmethod
@ -42,6 +44,10 @@ class AgentActionMessageLog(AgentAction):
prediction, and you need that LLM prediction (for future agent iteration).
Compared to `log`, this is useful when the underlying LLM is a
ChatModel (and therefore returns messages rather than a string)."""
# Ignoring type because we're overriding the type from AgentAction.
# And this is the correct thing to do in this case.
# The type literal is used for serialization purposes.
type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog"  # type: ignore
class AgentFinish(Serializable):
@ -56,8 +62,10 @@ class AgentFinish(Serializable):
`Final Answer: 2` you may want to just return `2` as a return value, but pass
along the full string as a `log` (for debugging or observability purposes).
"""
type: Literal["AgentFinish"] = "AgentFinish"
def __init__(self, return_values: dict, log: str, **kwargs: Any):
"""Override init to support instantiation by position for backward compat."""
super().__init__(return_values=return_values, log=log, **kwargs)
@classmethod

@ -6,7 +6,8 @@ from typing import Union
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
from langchain.pydantic_v1 import BaseModel
from langchain.schema import AgentAction, AgentFinish, Document
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
@ -104,6 +105,9 @@ def test_serialization_of_wellknown_objects() -> None:
AIMessageChunk,
StringPromptValue,
ChatPromptValueConcrete,
AgentFinish,
AgentAction,
AgentActionMessageLog,
]
lc_objects = [
@ -132,6 +136,14 @@ def test_serialization_of_wellknown_objects() -> None:
StringPromptValue(text="hello"),
ChatPromptValueConcrete(messages=[HumanMessage(content="human")]),
Document(page_content="hello"),
AgentFinish(return_values={}, log=""),
AgentAction(tool="tool", tool_input="input", log=""),
AgentActionMessageLog(
tool="tool",
tool_input="input",
log="",
message_log=[HumanMessage(content="human")],
),
]
for lc_object in lc_objects:

Loading…
Cancel
Save