Add `type` to Agent actions (#11682)

Add `type` to agent actions.
pull/11400/head
Eugene Yurtsev 9 months ago committed by GitHub
parent c14a8df2ee
commit 17b5090c18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Sequence, Union
from typing import Any, Literal, Sequence, Union
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
@ -21,10 +21,12 @@ class AgentAction(Serializable):
thoughts. This is useful when (tool, tool_input) does not contain
full information about the LLM prediction (for example, any `thought`
before the tool/tool_input)."""
type: Literal["AgentAction"] = "AgentAction"
def __init__(
self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any
):
"""Override init to support instantiation by position for backward compat."""
super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs)
@classmethod
@ -42,6 +44,10 @@ class AgentActionMessageLog(AgentAction):
prediction, and you need that LLM prediction (for future agent iteration).
Compared to `log`, this is useful when the underlying LLM is a
ChatModel (and therefore returns messages rather than a string)."""
# Ignoring type because we're overriding the type from AgentAction.
# And this is the correct thing to do in this case.
# The type literal is used for serialization purposes.
type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore
class AgentFinish(Serializable):
@ -56,8 +62,10 @@ class AgentFinish(Serializable):
`Final Answer: 2` you may want to just return `2` as a return value, but pass
along the full string as a `log` (for debugging or observability purposes).
"""
type: Literal["AgentFinish"] = "AgentFinish"
def __init__(self, return_values: dict, log: str, **kwargs: Any):
"""Override init to support instantiation by position for backward compat."""
super().__init__(return_values=return_values, log=log, **kwargs)
@classmethod

@ -6,7 +6,8 @@ from typing import Union
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
from langchain.pydantic_v1 import BaseModel
from langchain.schema import Document
from langchain.schema import AgentAction, AgentFinish, Document
from langchain.schema.agent import AgentActionMessageLog
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
@ -104,6 +105,9 @@ def test_serialization_of_wellknown_objects() -> None:
AIMessageChunk,
StringPromptValue,
ChatPromptValueConcrete,
AgentFinish,
AgentAction,
AgentActionMessageLog,
]
lc_objects = [
@ -132,6 +136,14 @@ def test_serialization_of_wellknown_objects() -> None:
StringPromptValue(text="hello"),
ChatPromptValueConcrete(messages=[HumanMessage(content="human")]),
Document(page_content="hello"),
AgentFinish(return_values={}, log=""),
AgentAction(tool="tool", tool_input="input", log=""),
AgentActionMessageLog(
tool="tool",
tool_input="input",
log="",
message_log=[HumanMessage(content="human")],
),
]
for lc_object in lc_objects:

Loading…
Cancel
Save