diff --git a/libs/langchain/langchain/schema/agent.py b/libs/langchain/langchain/schema/agent.py index a6d269b7e0..447dcbd441 100644 --- a/libs/langchain/langchain/schema/agent.py +++ b/libs/langchain/langchain/schema/agent.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Sequence, Union +from typing import Any, Literal, Sequence, Union from langchain.load.serializable import Serializable from langchain.schema.messages import BaseMessage @@ -21,10 +21,12 @@ class AgentAction(Serializable): thoughts. This is useful when (tool, tool_input) does not contain full information about the LLM prediction (for example, any `thought` before the tool/tool_input).""" + type: Literal["AgentAction"] = "AgentAction" def __init__( self, tool: str, tool_input: Union[str, dict], log: str, **kwargs: Any ): + """Override init to support instantiation by position for backward compat.""" super().__init__(tool=tool, tool_input=tool_input, log=log, **kwargs) @classmethod @@ -42,6 +44,10 @@ class AgentActionMessageLog(AgentAction): prediction, and you need that LLM prediction (for future agent iteration). Compared to `log`, this is useful when the underlying LLM is a ChatModel (and therefore returns messages rather than a string).""" + # Ignoring type because we're overriding the type from AgentAction. + # And this is the correct thing to do in this case. + # The type literal is used for serialization purposes. + type: Literal["AgentActionMessageLog"] = "AgentActionMessageLog" # type: ignore class AgentFinish(Serializable): @@ -56,8 +62,10 @@ class AgentFinish(Serializable): `Final Answer: 2` you may want to just return `2` as a return value, but pass along the full string as a `log` (for debugging or observability purposes). """ + type: Literal["AgentFinish"] = "AgentFinish" def __init__(self, return_values: dict, log: str, **kwargs: Any): + """Override init to support instantiation by position for backward compat.""" super().__init__(return_values=return_values, log=log, **kwargs) @classmethod diff --git a/libs/langchain/tests/unit_tests/test_schema.py b/libs/langchain/tests/unit_tests/test_schema.py index 629f5bcbeb..4077d8800e 100644 --- a/libs/langchain/tests/unit_tests/test_schema.py +++ b/libs/langchain/tests/unit_tests/test_schema.py @@ -6,7 +6,8 @@ from typing import Union from langchain.prompts.base import StringPromptValue from langchain.prompts.chat import ChatPromptValueConcrete from langchain.pydantic_v1 import BaseModel -from langchain.schema import Document +from langchain.schema import AgentAction, AgentFinish, Document +from langchain.schema.agent import AgentActionMessageLog from langchain.schema.messages import ( AIMessage, AIMessageChunk, @@ -104,6 +105,9 @@ def test_serialization_of_wellknown_objects() -> None: AIMessageChunk, StringPromptValue, ChatPromptValueConcrete, + AgentFinish, + AgentAction, + AgentActionMessageLog, ] lc_objects = [ @@ -132,6 +136,14 @@ def test_serialization_of_wellknown_objects() -> None: StringPromptValue(text="hello"), ChatPromptValueConcrete(messages=[HumanMessage(content="human")]), Document(page_content="hello"), + AgentFinish(return_values={}, log=""), + AgentAction(tool="tool", tool_input="input", log=""), + AgentActionMessageLog( + tool="tool", + tool_input="input", + log="", + message_log=[HumanMessage(content="human")], + ), ] for lc_object in lc_objects: