From 3bfe7cf467a26055f1dae9b9fd8e2ab70574d7b0 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 1 Jul 2023 10:39:19 -0700 Subject: [PATCH] Harrison/split schema dir (#7025) There should be no functional changes. langchain/schema/__init__.py still exposes most of the old names, so existing imports keep working for backwards compatibility. --------- Co-authored-by: Dev 2049 Co-authored-by: Bagatur --- langchain/agents/agent.py | 2 +- .../agents/agent_toolkits/pandas/base.py | 2 +- .../agents/agent_toolkits/python/base.py | 2 +- langchain/agents/agent_toolkits/sql/base.py | 2 +- langchain/agents/conversational_chat/base.py | 4 +- .../agents/openai_functions_agent/base.py | 4 +- .../openai_functions_multi_agent/base.py | 4 +- langchain/base_language.py | 3 +- langchain/callbacks/base.py | 5 +- langchain/callbacks/manager.py | 3 +- langchain/callbacks/promptlayer_callback.py | 6 +- langchain/callbacks/tracers/langchain.py | 2 +- langchain/callbacks/tracers/langchain_v1.py | 2 +- .../chains/conversational_retrieval/base.py | 3 +- .../openai_functions/citation_fuzzy_match.py | 2 +- .../openai_functions/qa_with_structure.py | 3 +- langchain/chat_models/anthropic.py | 6 +- langchain/chat_models/base.py | 4 +- langchain/chat_models/fake.py | 2 +- langchain/chat_models/google_palm.py | 6 +- langchain/chat_models/openai.py | 6 +- langchain/chat_models/promptlayer_openai.py | 3 +- langchain/chat_models/vertexai.py | 6 +- langchain/client/runner_utils.py | 6 +- .../agents/trajectory_eval_prompt.py | 4 +- .../autonomous_agents/autogpt/agent.py | 4 +- .../autonomous_agents/autogpt/prompt.py | 2 +- .../plan_and_execute/planners/chat_planner.py | 2 +- langchain/llms/base.py | 4 +- langchain/memory/buffer.py | 2 +- langchain/memory/buffer_window.py | 2 +- .../chat_message_histories/cassandra.py | 4 +- .../chat_message_histories/cosmos_db.py | 4 +- .../memory/chat_message_histories/dynamodb.py | 2 + .../memory/chat_message_histories/file.py | 4 +- .../chat_message_histories/firestore.py | 4 +- .../chat_message_histories/in_memory.py | 2 +- .../memory/chat_message_histories/momento.py | 4 +- .../memory/chat_message_histories/mongodb.py | 4 +- .../memory/chat_message_histories/postgres.py | 4 +- .../memory/chat_message_histories/redis.py | 4 +- .../memory/chat_message_histories/sql.py | 4 +- .../memory/chat_message_histories/zep.py | 4 +- langchain/memory/entity.py | 2 +- langchain/memory/kg.py | 6 +- langchain/memory/motorhead_memory.py | 2 +- langchain/memory/summary.py | 4 +- langchain/memory/summary_buffer.py | 2 +- langchain/memory/token_buffer.py | 2 +- langchain/memory/utils.py | 2 +- langchain/prompts/base.py | 3 +- langchain/prompts/chat.py | 6 +- langchain/schema.py | 886 ------------------ langchain/schema/__init__.py | 67 ++ langchain/schema/agent.py | 25 + langchain/schema/document.py | 82 ++ langchain/schema/memory.py | 121 +++ langchain/schema/messages.py | 183 ++++ langchain/schema/output.py | 118 +++ langchain/schema/output_parser.py | 172 ++++ langchain/schema/prompt.py | 23 + langchain/schema/retriever.py | 191 ++++ .../chat_models/test_anthropic.py | 4 +- .../chat_models/test_google_palm.py | 4 +- .../chat_models/test_openai.py | 4 +- .../chat_models/test_promptlayer_openai.py | 4 +- .../chat_models/test_vertexai.py | 6 +- .../memory/test_cassandra.py | 5 +- .../memory/test_cosmos_db.py | 2 +- .../memory/test_firestore.py | 2 +- .../integration_tests/memory/test_momento.py | 2 +- .../integration_tests/memory/test_mongodb.py | 2 +- tests/integration_tests/memory/test_redis.py | 2 +- .../callbacks/fake_callback_handler.py | 2 +- 
.../callbacks/tracers/test_base_tracer.py | 3 +- .../callbacks/tracers/test_langchain_v1.py | 3 +- .../chat_models/test_google_palm.py | 6 +- tests/unit_tests/chat_models/test_openai.py | 4 +- tests/unit_tests/llms/fake_chat_model.py | 3 +- tests/unit_tests/llms/test_callbacks.py | 2 +- .../chat_message_histories/test_file.py | 2 +- .../memory/chat_message_histories/test_sql.py | 2 +- .../memory/chat_message_histories/test_zep.py | 2 +- tests/unit_tests/prompts/test_chat.py | 2 +- tests/unit_tests/test_cache.py | 4 +- tests/unit_tests/test_schema.py | 2 +- 86 files changed, 1095 insertions(+), 1028 deletions(-) delete mode 100644 langchain/schema.py create mode 100644 langchain/schema/__init__.py create mode 100644 langchain/schema/agent.py create mode 100644 langchain/schema/document.py create mode 100644 langchain/schema/memory.py create mode 100644 langchain/schema/messages.py create mode 100644 langchain/schema/output.py create mode 100644 langchain/schema/output_parser.py create mode 100644 langchain/schema/prompt.py create mode 100644 langchain/schema/retriever.py diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 1f26ff2920..9c6c6b0c22 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -32,10 +32,10 @@ from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( AgentAction, AgentFinish, - BaseMessage, BaseOutputParser, OutputParserException, ) +from langchain.schema.messages import BaseMessage from langchain.tools.base import BaseTool from langchain.utilities.asyncio import asyncio_timeout diff --git a/langchain/agents/agent_toolkits/pandas/base.py b/langchain/agents/agent_toolkits/pandas/base.py index 03eddbc897..a695dc8eed 100644 --- a/langchain/agents/agent_toolkits/pandas/base.py +++ b/langchain/agents/agent_toolkits/pandas/base.py @@ -20,7 +20,7 @@ from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate -from langchain.schema import SystemMessage +from langchain.schema.messages import SystemMessage from langchain.tools.python.tool import PythonAstREPLTool diff --git a/langchain/agents/agent_toolkits/python/base.py b/langchain/agents/agent_toolkits/python/base.py index e86c6ca647..0cb8d7bd6b 100644 --- a/langchain/agents/agent_toolkits/python/base.py +++ b/langchain/agents/agent_toolkits/python/base.py @@ -10,7 +10,7 @@ from langchain.agents.types import AgentType from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain -from langchain.schema import SystemMessage +from langchain.schema.messages import SystemMessage from langchain.tools.python.tool import PythonREPLTool diff --git a/langchain/agents/agent_toolkits/sql/base.py b/langchain/agents/agent_toolkits/sql/base.py index 82fe07cca9..790fecc985 100644 --- a/langchain/agents/agent_toolkits/sql/base.py +++ b/langchain/agents/agent_toolkits/sql/base.py @@ -20,7 +20,7 @@ from langchain.prompts.chat import ( HumanMessagePromptTemplate, MessagesPlaceholder, ) -from langchain.schema import AIMessage, SystemMessage +from langchain.schema.messages import AIMessage, SystemMessage def create_sql_agent( diff --git a/langchain/agents/conversational_chat/base.py b/langchain/agents/conversational_chat/base.py index 0e103720c4..c399a10a65 100644 --- a/langchain/agents/conversational_chat/base.py +++ 
b/langchain/agents/conversational_chat/base.py @@ -25,11 +25,9 @@ from langchain.prompts.chat import ( ) from langchain.schema import ( AgentAction, - AIMessage, - BaseMessage, BaseOutputParser, - HumanMessage, ) +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from langchain.tools.base import BaseTool diff --git a/langchain/agents/openai_functions_agent/base.py b/langchain/agents/openai_functions_agent/base.py index 15c35b4de9..619820d461 100644 --- a/langchain/agents/openai_functions_agent/base.py +++ b/langchain/agents/openai_functions_agent/base.py @@ -21,10 +21,12 @@ from langchain.prompts.chat import ( from langchain.schema import ( AgentAction, AgentFinish, + OutputParserException, +) +from langchain.schema.messages import ( AIMessage, BaseMessage, FunctionMessage, - OutputParserException, SystemMessage, ) from langchain.tools import BaseTool diff --git a/langchain/agents/openai_functions_multi_agent/base.py b/langchain/agents/openai_functions_multi_agent/base.py index 4b4cbbbc6c..6ab0038eb0 100644 --- a/langchain/agents/openai_functions_multi_agent/base.py +++ b/langchain/agents/openai_functions_multi_agent/base.py @@ -21,10 +21,12 @@ from langchain.prompts.chat import ( from langchain.schema import ( AgentAction, AgentFinish, + OutputParserException, +) +from langchain.schema.messages import ( AIMessage, BaseMessage, FunctionMessage, - OutputParserException, SystemMessage, ) from langchain.tools import BaseTool diff --git a/langchain/base_language.py b/langchain/base_language.py index f02e43d613..4059dc8c47 100644 --- a/langchain/base_language.py +++ b/langchain/base_language.py @@ -6,7 +6,8 @@ from typing import Any, List, Optional, Sequence, Set from langchain.callbacks.manager import Callbacks from langchain.load.serializable import Serializable -from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string +from langchain.schema import LLMResult, PromptValue +from langchain.schema.messages import BaseMessage, get_buffer_string def _get_token_ids_default_method(text: str) -> List[int]: diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 6092970d01..4d9d4c75c6 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -4,7 +4,10 @@ from __future__ import annotations from typing import Any, Dict, List, Optional, Sequence, Union from uuid import UUID -from langchain.schema import AgentAction, AgentFinish, BaseMessage, Document, LLMResult +from langchain.schema.agent import AgentAction, AgentFinish +from langchain.schema.document import Document +from langchain.schema.messages import BaseMessage +from langchain.schema.output import LLMResult class RetrieverManagerMixin: diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 568d11a9a8..00d87af384 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -41,11 +41,10 @@ from langchain.callbacks.tracers.wandb import WandbTracer from langchain.schema import ( AgentAction, AgentFinish, - BaseMessage, Document, LLMResult, - get_buffer_string, ) +from langchain.schema.messages import BaseMessage, get_buffer_string logger = logging.getLogger(__name__) Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] diff --git a/langchain/callbacks/promptlayer_callback.py b/langchain/callbacks/promptlayer_callback.py index c6a4f1cc47..bd93d70879 100644 --- a/langchain/callbacks/promptlayer_callback.py +++ b/langchain/callbacks/promptlayer_callback.py @@ -7,12 +7,14 @@ from uuid 
import UUID from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import ( + ChatGeneration, + LLMResult, +) +from langchain.schema.messages import ( AIMessage, BaseMessage, - ChatGeneration, ChatMessage, HumanMessage, - LLMResult, SystemMessage, ) diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index da8e20512c..1c0095dd4d 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -17,7 +17,7 @@ from langchain.callbacks.tracers.schemas import ( TracerSession, ) from langchain.env import get_runtime_environment -from langchain.schema import BaseMessage, messages_to_dict +from langchain.schema.messages import BaseMessage, messages_to_dict logger = logging.getLogger(__name__) _LOGGED = set() diff --git a/langchain/callbacks/tracers/langchain_v1.py b/langchain/callbacks/tracers/langchain_v1.py index 171983c301..d7825ec245 100644 --- a/langchain/callbacks/tracers/langchain_v1.py +++ b/langchain/callbacks/tracers/langchain_v1.py @@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import ( TracerSessionV1, TracerSessionV1Base, ) -from langchain.schema import get_buffer_string +from langchain.schema.messages import get_buffer_string from langchain.utils import raise_for_status_with_text diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index 4cd6fefc66..d369fb31a4 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -22,7 +22,8 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_ from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseMessage, BaseRetriever, Document +from langchain.schema import BaseRetriever, Document +from langchain.schema.messages import BaseMessage from langchain.vectorstores.base import VectorStore # Depending on the memory type and configuration, the chat history format may differ. 
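Every hunk above and below follows the same mechanical pattern: message classes and their helpers now come from the new langchain.schema.messages submodule, while langchain.schema re-exports them (see langchain/schema/__init__.py near the end of this patch). A minimal sketch of the compatibility guarantee this is aiming for; the message contents are illustrative values, and the expected output follows the get_buffer_string implementation in the deleted langchain/schema.py below:

    # Old flat path: still valid, because langchain/schema/__init__.py
    # re-exports these names for backwards compatibility.
    from langchain.schema import AIMessage, HumanMessage, get_buffer_string

    # New canonical path introduced by this patch.
    from langchain.schema.messages import AIMessage, HumanMessage, get_buffer_string

    messages = [HumanMessage(content="Hi"), AIMessage(content="Hello")]
    # get_buffer_string joins messages with role prefixes ("Human", "AI"),
    # per the implementation moved out of langchain/schema.py.
    assert get_buffer_string(messages) == "Human: Hi\nAI: Hello"

Because the re-export binds the very same class objects, isinstance checks and serialized message dicts are unaffected by which path a caller uses.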
diff --git a/langchain/chains/openai_functions/citation_fuzzy_match.py b/langchain/chains/openai_functions/citation_fuzzy_match.py index 3e9b8a3bee..ac812c4904 100644 --- a/langchain/chains/openai_functions/citation_fuzzy_match.py +++ b/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -9,7 +9,7 @@ from langchain.output_parsers.openai_functions import ( PydanticOutputFunctionsParser, ) from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.schema import HumanMessage, SystemMessage +from langchain.schema.messages import HumanMessage, SystemMessage class FactWithEvidence(BaseModel): diff --git a/langchain/chains/openai_functions/qa_with_structure.py b/langchain/chains/openai_functions/qa_with_structure.py index 7f022cd1ee..a3c0584db8 100644 --- a/langchain/chains/openai_functions/qa_with_structure.py +++ b/langchain/chains/openai_functions/qa_with_structure.py @@ -11,7 +11,8 @@ from langchain.output_parsers.openai_functions import ( ) from langchain.prompts import PromptTemplate from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.schema import BaseLLMOutputParser, HumanMessage, SystemMessage +from langchain.schema import BaseLLMOutputParser +from langchain.schema.messages import HumanMessage, SystemMessage class AnswerWithSources(BaseModel): diff --git a/langchain/chat_models/anthropic.py b/langchain/chat_models/anthropic.py index edb036526b..4fa492def4 100644 --- a/langchain/chat_models/anthropic.py +++ b/langchain/chat_models/anthropic.py @@ -7,11 +7,13 @@ from langchain.callbacks.manager import ( from langchain.chat_models.base import BaseChatModel from langchain.llms.anthropic import _AnthropicCommon from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( AIMessage, BaseMessage, - ChatGeneration, ChatMessage, - ChatResult, HumanMessage, SystemMessage, ) diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 51b3a5a9bb..132971da42 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -19,15 +19,13 @@ from langchain.callbacks.manager import ( ) from langchain.load.dump import dumpd, dumps from langchain.schema import ( - AIMessage, - BaseMessage, ChatGeneration, ChatResult, - HumanMessage, LLMResult, PromptValue, RunInfo, ) +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage def _get_verbosity() -> bool: diff --git a/langchain/chat_models/fake.py b/langchain/chat_models/fake.py index 0149b1ce72..a974fe592a 100644 --- a/langchain/chat_models/fake.py +++ b/langchain/chat_models/fake.py @@ -3,7 +3,7 @@ from typing import Any, List, Mapping, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel -from langchain.schema import BaseMessage +from langchain.schema.messages import BaseMessage class FakeListChatModel(SimpleChatModel): diff --git a/langchain/chat_models/google_palm.py b/langchain/chat_models/google_palm.py index cf27c2bc46..699b5fab4d 100644 --- a/langchain/chat_models/google_palm.py +++ b/langchain/chat_models/google_palm.py @@ -19,11 +19,13 @@ from langchain.callbacks.manager import ( ) from langchain.chat_models.base import BaseChatModel from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( AIMessage, BaseMessage, - ChatGeneration, ChatMessage, - ChatResult, HumanMessage, SystemMessage, ) diff --git a/langchain/chat_models/openai.py 
b/langchain/chat_models/openai.py index f1725b8331..af1c6d3e1e 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -30,11 +30,13 @@ from langchain.callbacks.manager import ( ) from langchain.chat_models.base import BaseChatModel from langchain.schema import ( + ChatGeneration, + ChatResult, +) +from langchain.schema.messages import ( AIMessage, BaseMessage, - ChatGeneration, ChatMessage, - ChatResult, FunctionMessage, HumanMessage, SystemMessage, diff --git a/langchain/chat_models/promptlayer_openai.py b/langchain/chat_models/promptlayer_openai.py index ccb13b05b2..3889b53558 100644 --- a/langchain/chat_models/promptlayer_openai.py +++ b/langchain/chat_models/promptlayer_openai.py @@ -7,7 +7,8 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.chat_models import ChatOpenAI -from langchain.schema import BaseMessage, ChatResult +from langchain.schema import ChatResult +from langchain.schema.messages import BaseMessage class PromptLayerChatOpenAI(ChatOpenAI): diff --git a/langchain/chat_models/vertexai.py b/langchain/chat_models/vertexai.py index 5770da6e11..4b090be66a 100644 --- a/langchain/chat_models/vertexai.py +++ b/langchain/chat_models/vertexai.py @@ -11,10 +11,12 @@ from langchain.callbacks.manager import ( from langchain.chat_models.base import BaseChatModel from langchain.llms.vertexai import _VertexAICommon, is_codey_model from langchain.schema import ( - AIMessage, - BaseMessage, ChatGeneration, ChatResult, +) +from langchain.schema.messages import ( + AIMessage, + BaseMessage, HumanMessage, SystemMessage, ) diff --git a/langchain/client/runner_utils.py b/langchain/client/runner_utils.py index 10f380b445..73aa3c2b09 100644 --- a/langchain/client/runner_utils.py +++ b/langchain/client/runner_utils.py @@ -31,10 +31,12 @@ from langchain.chains.base import Chain from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM from langchain.schema import ( - BaseMessage, ChatResult, - HumanMessage, LLMResult, +) +from langchain.schema.messages import ( + BaseMessage, + HumanMessage, get_buffer_string, messages_from_dict, ) diff --git a/langchain/evaluation/agents/trajectory_eval_prompt.py b/langchain/evaluation/agents/trajectory_eval_prompt.py index 422f66ac8a..5f1f86eacc 100644 --- a/langchain/evaluation/agents/trajectory_eval_prompt.py +++ b/langchain/evaluation/agents/trajectory_eval_prompt.py @@ -1,8 +1,6 @@ """Prompt for trajectory evaluation chain.""" # flake8: noqa -from langchain.schema import AIMessage -from langchain.schema import HumanMessage -from langchain.schema import SystemMessage +from langchain.schema.messages import HumanMessage, AIMessage, SystemMessage from langchain.prompts.chat import ( ChatPromptTemplate, diff --git a/langchain/experimental/autonomous_agents/autogpt/agent.py b/langchain/experimental/autonomous_agents/autogpt/agent.py index 6139d94c4e..4d12f086d8 100644 --- a/langchain/experimental/autonomous_agents/autogpt/agent.py +++ b/langchain/experimental/autonomous_agents/autogpt/agent.py @@ -16,12 +16,10 @@ from langchain.experimental.autonomous_agents.autogpt.prompt_generator import ( ) from langchain.memory import ChatMessageHistory from langchain.schema import ( - AIMessage, BaseChatMessageHistory, Document, - HumanMessage, - SystemMessage, ) +from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage from langchain.tools.base import BaseTool from langchain.tools.human.tool import HumanInputRun from langchain.vectorstores.base import 
VectorStoreRetriever diff --git a/langchain/experimental/autonomous_agents/autogpt/prompt.py b/langchain/experimental/autonomous_agents/autogpt/prompt.py index d1f3a9b74a..f6645a1dba 100644 --- a/langchain/experimental/autonomous_agents/autogpt/prompt.py +++ b/langchain/experimental/autonomous_agents/autogpt/prompt.py @@ -7,7 +7,7 @@ from langchain.experimental.autonomous_agents.autogpt.prompt_generator import ge from langchain.prompts.chat import ( BaseChatPromptTemplate, ) -from langchain.schema import BaseMessage, HumanMessage, SystemMessage +from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage from langchain.tools.base import BaseTool from langchain.vectorstores.base import VectorStoreRetriever diff --git a/langchain/experimental/plan_and_execute/planners/chat_planner.py b/langchain/experimental/plan_and_execute/planners/chat_planner.py index 4b6e9da399..1fb879a68f 100644 --- a/langchain/experimental/plan_and_execute/planners/chat_planner.py +++ b/langchain/experimental/plan_and_execute/planners/chat_planner.py @@ -9,7 +9,7 @@ from langchain.experimental.plan_and_execute.schema import ( Step, ) from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate -from langchain.schema import SystemMessage +from langchain.schema.messages import SystemMessage SYSTEM_PROMPT = ( "Let's first understand the problem and devise a plan to solve the problem." diff --git a/langchain/llms/base.py b/langchain/llms/base.py index a25d65f659..9a15e5b173 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -22,14 +22,12 @@ from langchain.callbacks.manager import ( ) from langchain.load.dump import dumpd from langchain.schema import ( - AIMessage, - BaseMessage, Generation, LLMResult, PromptValue, RunInfo, - get_buffer_string, ) +from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string def _get_verbosity() -> bool: diff --git a/langchain/memory/buffer.py b/langchain/memory/buffer.py index f3623aaf21..50b1468b64 100644 --- a/langchain/memory/buffer.py +++ b/langchain/memory/buffer.py @@ -4,7 +4,7 @@ from pydantic import root_validator from langchain.memory.chat_memory import BaseChatMemory, BaseMemory from langchain.memory.utils import get_prompt_input_key -from langchain.schema import get_buffer_string +from langchain.schema.messages import get_buffer_string class ConversationBufferMemory(BaseChatMemory): diff --git a/langchain/memory/buffer_window.py b/langchain/memory/buffer_window.py index c9c0178736..34d70f59d3 100644 --- a/langchain/memory/buffer_window.py +++ b/langchain/memory/buffer_window.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import BaseMessage, get_buffer_string +from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationBufferWindowMemory(BaseChatMemory): diff --git a/langchain/memory/chat_message_histories/cassandra.py b/langchain/memory/chat_message_histories/cassandra.py index ec6c17528e..6646e55ed2 100644 --- a/langchain/memory/chat_message_histories/cassandra.py +++ b/langchain/memory/chat_message_histories/cassandra.py @@ -10,10 +10,8 @@ if typing.TYPE_CHECKING: from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - _message_to_dict, - messages_from_dict, ) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict DEFAULT_TABLE_NAME = "message_store" DEFAULT_TTL_SECONDS = None diff --git a/langchain/memory/chat_message_histories/cosmos_db.py 
b/langchain/memory/chat_message_histories/cosmos_db.py index 5318c805b7..e1e5ad3261 100644 --- a/langchain/memory/chat_message_histories/cosmos_db.py +++ b/langchain/memory/chat_message_histories/cosmos_db.py @@ -7,10 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Optional, Type from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - messages_from_dict, - messages_to_dict, ) +from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict logger = logging.getLogger(__name__) diff --git a/langchain/memory/chat_message_histories/dynamodb.py b/langchain/memory/chat_message_histories/dynamodb.py index ef9699221b..219ca9844a 100644 --- a/langchain/memory/chat_message_histories/dynamodb.py +++ b/langchain/memory/chat_message_histories/dynamodb.py @@ -3,6 +3,8 @@ from typing import List, Optional from langchain.schema import ( BaseChatMessageHistory, +) +from langchain.schema.messages import ( BaseMessage, _message_to_dict, messages_from_dict, diff --git a/langchain/memory/chat_message_histories/file.py b/langchain/memory/chat_message_histories/file.py index 0fbbf1e706..912ff740cb 100644 --- a/langchain/memory/chat_message_histories/file.py +++ b/langchain/memory/chat_message_histories/file.py @@ -5,10 +5,8 @@ from typing import List from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - messages_from_dict, - messages_to_dict, ) +from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict logger = logging.getLogger(__name__) diff --git a/langchain/memory/chat_message_histories/firestore.py b/langchain/memory/chat_message_histories/firestore.py index 3e325682b9..fdfb4e5669 100644 --- a/langchain/memory/chat_message_histories/firestore.py +++ b/langchain/memory/chat_message_histories/firestore.py @@ -6,10 +6,8 @@ from typing import TYPE_CHECKING, List, Optional from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - messages_from_dict, - messages_to_dict, ) +from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict logger = logging.getLogger(__name__) diff --git a/langchain/memory/chat_message_histories/in_memory.py b/langchain/memory/chat_message_histories/in_memory.py index bcb60d2e69..e3fa603f88 100644 --- a/langchain/memory/chat_message_histories/in_memory.py +++ b/langchain/memory/chat_message_histories/in_memory.py @@ -4,8 +4,8 @@ from pydantic import BaseModel from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, ) +from langchain.schema.messages import BaseMessage class ChatMessageHistory(BaseChatMessageHistory, BaseModel): diff --git a/langchain/memory/chat_message_histories/momento.py b/langchain/memory/chat_message_histories/momento.py index 885fe16b63..5fefc790ef 100644 --- a/langchain/memory/chat_message_histories/momento.py +++ b/langchain/memory/chat_message_histories/momento.py @@ -6,10 +6,8 @@ from typing import TYPE_CHECKING, Any, Optional from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - _message_to_dict, - messages_from_dict, ) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict from langchain.utils import get_from_env if TYPE_CHECKING: diff --git a/langchain/memory/chat_message_histories/mongodb.py b/langchain/memory/chat_message_histories/mongodb.py index a6965af3c0..5cc3af8dbe 100644 --- a/langchain/memory/chat_message_histories/mongodb.py +++ b/langchain/memory/chat_message_histories/mongodb.py @@ -4,10 +4,8 @@ from typing import List from langchain.schema import ( 
BaseChatMessageHistory, - BaseMessage, - _message_to_dict, - messages_from_dict, ) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict logger = logging.getLogger(__name__) diff --git a/langchain/memory/chat_message_histories/postgres.py b/langchain/memory/chat_message_histories/postgres.py index 3d06346d26..fb1d3b3488 100644 --- a/langchain/memory/chat_message_histories/postgres.py +++ b/langchain/memory/chat_message_histories/postgres.py @@ -4,10 +4,8 @@ from typing import List from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - _message_to_dict, - messages_from_dict, ) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict logger = logging.getLogger(__name__) diff --git a/langchain/memory/chat_message_histories/redis.py b/langchain/memory/chat_message_histories/redis.py index 33c965ebd8..442059b442 100644 --- a/langchain/memory/chat_message_histories/redis.py +++ b/langchain/memory/chat_message_histories/redis.py @@ -4,10 +4,8 @@ from typing import List, Optional from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - _message_to_dict, - messages_from_dict, ) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict logger = logging.getLogger(__name__) diff --git a/langchain/memory/chat_message_histories/sql.py b/langchain/memory/chat_message_histories/sql.py index 9520d159e3..85baf037e5 100644 --- a/langchain/memory/chat_message_histories/sql.py +++ b/langchain/memory/chat_message_histories/sql.py @@ -12,10 +12,8 @@ from sqlalchemy.orm import sessionmaker from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - _message_to_dict, - messages_from_dict, ) +from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict logger = logging.getLogger(__name__) diff --git a/langchain/memory/chat_message_histories/zep.py b/langchain/memory/chat_message_histories/zep.py index d124fe27e8..58dc26bdc4 100644 --- a/langchain/memory/chat_message_histories/zep.py +++ b/langchain/memory/chat_message_histories/zep.py @@ -4,11 +4,9 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional from langchain.schema import ( - AIMessage, BaseChatMessageHistory, - BaseMessage, - HumanMessage, ) +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage if TYPE_CHECKING: from zep_python import Memory, MemorySearchResult, Message, NotFoundError diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index 759da031e8..5c06bd197d 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -14,7 +14,7 @@ from langchain.memory.prompt import ( ) from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseMessage, get_buffer_string +from langchain.schema.messages import BaseMessage, get_buffer_string logger = logging.getLogger(__name__) diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py index 2c71a33c44..fd071c066d 100644 --- a/langchain/memory/kg.py +++ b/langchain/memory/kg.py @@ -13,11 +13,7 @@ from langchain.memory.prompt import ( ) from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate -from langchain.schema import ( - BaseMessage, - SystemMessage, - get_buffer_string, -) +from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string class ConversationKGMemory(BaseChatMemory): diff --git 
a/langchain/memory/motorhead_memory.py b/langchain/memory/motorhead_memory.py index 56576f15a2..f84ebe0f1f 100644 --- a/langchain/memory/motorhead_memory.py +++ b/langchain/memory/motorhead_memory.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import requests from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import get_buffer_string +from langchain.schema.messages import get_buffer_string MANAGED_URL = "https://api.getmetal.io/v1/motorhead" # LOCAL_URL = "http://localhost:8080" diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index c35bd70b93..1fd58196be 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -11,10 +11,8 @@ from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts.base import BasePromptTemplate from langchain.schema import ( BaseChatMessageHistory, - BaseMessage, - SystemMessage, - get_buffer_string, ) +from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string class SummarizerMixin(BaseModel): diff --git a/langchain/memory/summary_buffer.py b/langchain/memory/summary_buffer.py index 5e4c5b93ec..0b49797f59 100644 --- a/langchain/memory/summary_buffer.py +++ b/langchain/memory/summary_buffer.py @@ -4,7 +4,7 @@ from pydantic import root_validator from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.summary import SummarizerMixin -from langchain.schema import BaseMessage, get_buffer_string +from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): diff --git a/langchain/memory/token_buffer.py b/langchain/memory/token_buffer.py index c5e3c01b63..63b00007f4 100644 --- a/langchain/memory/token_buffer.py +++ b/langchain/memory/token_buffer.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List from langchain.base_language import BaseLanguageModel from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import BaseMessage, get_buffer_string +from langchain.schema.messages import BaseMessage, get_buffer_string class ConversationTokenBufferMemory(BaseChatMemory): diff --git a/langchain/memory/utils.py b/langchain/memory/utils.py index 4e1e7efb98..2706f1fc7e 100644 --- a/langchain/memory/utils.py +++ b/langchain/memory/utils.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List -from langchain.schema import get_buffer_string # noqa: 401 +from langchain.schema.messages import get_buffer_string # noqa: 401 def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str: diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index d8fa2dbdf2..6852a50558 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -11,7 +11,8 @@ from pydantic import Field, root_validator from langchain.formatting import formatter from langchain.load.serializable import Serializable -from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue +from langchain.schema import BaseOutputParser, PromptValue +from langchain.schema.messages import BaseMessage, HumanMessage def jinja2_formatter(template: str, **kwargs: Any) -> str: diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 7be1390b35..59cb44aa33 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -8,16 +8,18 @@ from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union from pydantic import Field, root_validator from langchain.load.serializable import Serializable 
-from langchain.memory.buffer import get_buffer_string from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( + PromptValue, +) +from langchain.schema.messages import ( AIMessage, BaseMessage, ChatMessage, HumanMessage, - PromptValue, SystemMessage, + get_buffer_string, ) diff --git a/langchain/schema.py b/langchain/schema.py deleted file mode 100644 index 162f97e2f9..0000000000 --- a/langchain/schema.py +++ /dev/null @@ -1,886 +0,0 @@ -"""Common schema objects.""" -from __future__ import annotations - -import warnings -from abc import ABC, abstractmethod -from copy import deepcopy -from dataclasses import dataclass -from inspect import signature -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Generic, - List, - NamedTuple, - Optional, - Sequence, - TypeVar, - Union, -) -from uuid import UUID - -from pydantic import BaseModel, Field, root_validator - -from langchain.load.serializable import Serializable - -if TYPE_CHECKING: - from langchain.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, - Callbacks, - ) - -RUN_KEY = "__run" - - -def get_buffer_string( - messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI" -) -> str: - """Convert sequence of Messages to strings and concatenate them into one string. - - Args: - messages: Messages to be converted to strings. - human_prefix: The prefix to prepend to contents of HumanMessages. - ai_prefix: THe prefix to prepend to contents of AIMessages. - - Returns: - A single string concatenation of all input messages. - - Example: - .. code-block:: python - - from langchain.schema import AIMessage, HumanMessage - - messages = [ - HumanMessage(content="Hi, how are you?"), - AIMessage(content="Good, how are you?"), - ] - get_buffer_string(messages) - # -> "Human: Hi, how are you?\nAI: Good, how are you?" - """ - string_messages = [] - for m in messages: - if isinstance(m, HumanMessage): - role = human_prefix - elif isinstance(m, AIMessage): - role = ai_prefix - elif isinstance(m, SystemMessage): - role = "System" - elif isinstance(m, FunctionMessage): - role = "Function" - elif isinstance(m, ChatMessage): - role = m.role - else: - raise ValueError(f"Got unsupported message type: {m}") - message = f"{role}: {m.content}" - if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs: - message += f"{m.additional_kwargs['function_call']}" - string_messages.append(message) - - return "\n".join(string_messages) - - -@dataclass -class AgentAction: - """A full description of an action for an ActionAgent to execute.""" - - tool: str - """The name of the Tool to execute.""" - tool_input: Union[str, dict] - """The input to pass in to the Tool.""" - log: str - """Additional information to log about the action.""" - - -class AgentFinish(NamedTuple): - """The final return value of an ActionAgent.""" - - return_values: dict - """Dictionary of return values.""" - log: str - """Additional information to log about the return value""" - - -class Generation(Serializable): - """A single text generation output.""" - - text: str - """Generated text output.""" - - generation_info: Optional[Dict[str, Any]] = None - """Raw response from the provider. May include things like the - reason for finishing or token log probabilities. 
- """ - # TODO: add log probs as separate attribute - - @property - def lc_serializable(self) -> bool: - """Whether this class is LangChain serializable.""" - return True - - -class BaseMessage(Serializable): - """The base abstract Message class. - - Messages are the inputs and outputs of ChatModels. - """ - - content: str - """The string contents of the message.""" - - additional_kwargs: dict = Field(default_factory=dict) - """Any additional information.""" - - @property - @abstractmethod - def type(self) -> str: - """Type of the Message, used for serialization.""" - - @property - def lc_serializable(self) -> bool: - """Whether this class is LangChain serializable.""" - return True - - -class HumanMessage(BaseMessage): - """A Message from a human.""" - - example: bool = False - """Whether this Message is being passed in to the model as part of an example - conversation. - """ - - @property - def type(self) -> str: - """Type of the message, used for serialization.""" - return "human" - - -class AIMessage(BaseMessage): - """A Message from an AI.""" - - example: bool = False - """Whether this Message is being passed in to the model as part of an example - conversation. - """ - - @property - def type(self) -> str: - """Type of the message, used for serialization.""" - return "ai" - - -class SystemMessage(BaseMessage): - """A Message for priming AI behavior, usually passed in as the first of a sequence - of input messages. - """ - - @property - def type(self) -> str: - """Type of the message, used for serialization.""" - return "system" - - -class FunctionMessage(BaseMessage): - """A Message for passing the result of executing a function back to a model.""" - - name: str - """The name of the function that was executed.""" - - @property - def type(self) -> str: - """Type of the message, used for serialization.""" - return "function" - - -class ChatMessage(BaseMessage): - """A Message that can be assigned an arbitrary speaker (i.e. role).""" - - role: str - """The speaker / role of the Message.""" - - @property - def type(self) -> str: - """Type of the message, used for serialization.""" - return "chat" - - -def _message_to_dict(message: BaseMessage) -> dict: - return {"type": message.type, "data": message.dict()} - - -def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]: - """Convert a sequence of Messages to a list of dictionaries. - - Args: - messages: Sequence of messages (as BaseMessages) to convert. - - Returns: - List of messages as dicts. - """ - return [_message_to_dict(m) for m in messages] - - -def _message_from_dict(message: dict) -> BaseMessage: - _type = message["type"] - if _type == "human": - return HumanMessage(**message["data"]) - elif _type == "ai": - return AIMessage(**message["data"]) - elif _type == "system": - return SystemMessage(**message["data"]) - elif _type == "chat": - return ChatMessage(**message["data"]) - else: - raise ValueError(f"Got unexpected type: {_type}") - - -def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: - """Convert a sequence of messages from dicts to Message objects. - - Args: - messages: Sequence of messages (as dicts) to convert. - - Returns: - List of messages (BaseMessages). 
- """ - return [_message_from_dict(m) for m in messages] - - -class ChatGeneration(Generation): - """A single chat generation output.""" - - text: str = "" - """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message.""" - message: BaseMessage - """The message output by the chat model.""" - - @root_validator - def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]: - """Set the text attribute to be the contents of the message.""" - values["text"] = values["message"].content - return values - - -class RunInfo(BaseModel): - """Class that contains metadata for a single execution of a Chain or model.""" - - run_id: UUID - """A unique identifier for the model or chain run.""" - - -class ChatResult(BaseModel): - """Class that contains all results for a single chat model call.""" - - generations: List[ChatGeneration] - """List of the chat generations. This is a List because an input can have multiple - candidate generations. - """ - llm_output: Optional[dict] = None - """For arbitrary LLM provider specific output.""" - - -class LLMResult(BaseModel): - """Class that contains all results for a batched LLM call.""" - - generations: List[List[Generation]] - """List of generated outputs. This is a List[List[]] because - each input could have multiple candidate generations.""" - llm_output: Optional[dict] = None - """Arbitrary LLM provider-specific output.""" - run: Optional[List[RunInfo]] = None - """List of metadata info for model call for each input.""" - - def flatten(self) -> List[LLMResult]: - """Flatten generations into a single list. - - Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult - contains only a single Generation. If token usage information is available, - it is kept only for the LLMResult corresponding to the top-choice - Generation, to avoid over-counting of token usage downstream. - - Returns: - List of LLMResults where each returned LLMResult contains a single - Generation. - """ - llm_results = [] - for i, gen_list in enumerate(self.generations): - # Avoid double counting tokens in OpenAICallback - if i == 0: - llm_results.append( - LLMResult( - generations=[gen_list], - llm_output=self.llm_output, - ) - ) - else: - if self.llm_output is not None: - llm_output = deepcopy(self.llm_output) - llm_output["token_usage"] = dict() - else: - llm_output = None - llm_results.append( - LLMResult( - generations=[gen_list], - llm_output=llm_output, - ) - ) - return llm_results - - def __eq__(self, other: object) -> bool: - """Check for LLMResult equality by ignoring any metadata related to runs.""" - if not isinstance(other, LLMResult): - return NotImplemented - return ( - self.generations == other.generations - and self.llm_output == other.llm_output - ) - - -class PromptValue(Serializable, ABC): - """Base abstract class for inputs to any language model. - - PromptValues can be converted to both LLM (pure text-generation) inputs and - ChatModel inputs. - """ - - @abstractmethod - def to_string(self) -> str: - """Return prompt value as string.""" - - @abstractmethod - def to_messages(self) -> List[BaseMessage]: - """Return prompt as a list of Messages.""" - - -class BaseMemory(Serializable, ABC): - """Base abstract class for memory in Chains. - - Memory refers to state in Chains. Memory can be used to store information about - past executions of a Chain and inject that information into the inputs of - future executions of the Chain. 
For example, for conversational Chains Memory - can be used to store conversations and automatically add them to future model - prompts so that the model has the necessary context to respond coherently to - the latest input. - - Example: - .. code-block:: python - - class SimpleMemory(BaseMemory): - memories: Dict[str, Any] = dict() - - @property - def memory_variables(self) -> List[str]: - return list(self.memories.keys()) - - def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: - return self.memories - - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: - pass - - def clear(self) -> None: - pass - """ # noqa: E501 - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @property - @abstractmethod - def memory_variables(self) -> List[str]: - """The string keys this memory class will add to chain inputs.""" - - @abstractmethod - def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - """Return key-value pairs given the text input to the chain.""" - - @abstractmethod - def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: - """Save the context of this chain run to memory.""" - - @abstractmethod - def clear(self) -> None: - """Clear memory contents.""" - - -class BaseChatMessageHistory(ABC): - """Abstract base class for storing chat message history. - - See `ChatMessageHistory` for default implementation. - - Example: - .. code-block:: python - - class FileChatMessageHistory(BaseChatMessageHistory): - storage_path: str - session_id: str - - @property - def messages(self): - with open(os.path.join(storage_path, session_id), 'r:utf-8') as f: - messages = json.loads(f.read()) - return messages_from_dict(messages) - - def add_message(self, message: BaseMessage) -> None: - messages = self.messages.append(_message_to_dict(message)) - with open(os.path.join(storage_path, session_id), 'w') as f: - json.dump(f, messages) - - def clear(self): - with open(os.path.join(storage_path, session_id), 'w') as f: - f.write("[]") - """ - - messages: List[BaseMessage] - """A list of Messages stored in-memory.""" - - def add_user_message(self, message: str) -> None: - """Convenience method for adding a human message string to the store. - - Args: - message: The string contents of a human message. - """ - self.add_message(HumanMessage(content=message)) - - def add_ai_message(self, message: str) -> None: - """Convenience method for adding an AI message string to the store. - - Args: - message: The string contents of an AI message. - """ - self.add_message(AIMessage(content=message)) - - # TODO: Make this an abstractmethod. - def add_message(self, message: BaseMessage) -> None: - """Add a Message object to the store. - - Args: - message: A BaseMessage object to store. - """ - raise NotImplementedError - - @abstractmethod - def clear(self) -> None: - """Remove all messages from the store""" - - -class Document(Serializable): - """Class for storing a piece of text and associated metadata.""" - - page_content: str - """String text.""" - metadata: dict = Field(default_factory=dict) - """Arbitrary metadata about the page content (e.g., source, relationships to other - documents, etc.). - """ - - -class BaseRetriever(ABC): - """Abstract base class for a Document retrieval system. - - A retrieval system is defined as something that can take string queries and return - the most 'relevant' Documents from some source. - - Example: - .. 
code-block:: python - - class TFIDFRetriever(BaseRetriever, BaseModel): - vectorizer: Any - docs: List[Document] - tfidf_array: Any - k: int = 4 - - class Config: - arbitrary_types_allowed = True - - def get_relevant_documents(self, query: str) -> List[Document]: - from sklearn.metrics.pairwise import cosine_similarity - - # Ip -- (n_docs,x), Op -- (n_docs,n_Feats) - query_vec = self.vectorizer.transform([query]) - # Op -- (n_docs,1) -- Cosine Sim with each doc - results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,)) - return [self.docs[i] for i in results.argsort()[-self.k :][::-1]] - - async def aget_relevant_documents(self, query: str) -> List[Document]: - raise NotImplementedError - - """ # noqa: E501 - - _new_arg_supported: bool = False - _expects_other_args: bool = False - - def __init_subclass__(cls, **kwargs: Any) -> None: - super().__init_subclass__(**kwargs) - # Version upgrade for old retrievers that implemented the public - # methods directly. - if cls.get_relevant_documents != BaseRetriever.get_relevant_documents: - warnings.warn( - "Retrievers must implement abstract `_get_relevant_documents` method" - " instead of `get_relevant_documents`", - DeprecationWarning, - ) - swap = cls.get_relevant_documents - cls.get_relevant_documents = ( # type: ignore[assignment] - BaseRetriever.get_relevant_documents - ) - cls._get_relevant_documents = swap # type: ignore[assignment] - if ( - hasattr(cls, "aget_relevant_documents") - and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents - ): - warnings.warn( - "Retrievers must implement abstract `_aget_relevant_documents` method" - " instead of `aget_relevant_documents`", - DeprecationWarning, - ) - aswap = cls.aget_relevant_documents - cls.aget_relevant_documents = ( # type: ignore[assignment] - BaseRetriever.aget_relevant_documents - ) - cls._aget_relevant_documents = aswap # type: ignore[assignment] - parameters = signature(cls._get_relevant_documents).parameters - cls._new_arg_supported = parameters.get("run_manager") is not None - # If a V1 retriever broke the interface and expects additional arguments - cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2 - - @abstractmethod - def _get_relevant_documents( - self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any - ) -> List[Document]: - """Get documents relevant to a query. - Args: - query: String to find relevant documents for. - run_manager: The callbacks handler to use. - Returns: - List of relevant documents - """ - - @abstractmethod - async def _aget_relevant_documents( - self, - query: str, - *, - run_manager: AsyncCallbackManagerForRetrieverRun, - **kwargs: Any, - ) -> List[Document]: - """Asynchronously get documents relevant to a query. - Args: - query: string to find relevant documents for - run_manager: The callbacks handler to use - Returns: - List of relevant documents - """ - - def get_relevant_documents( - self, query: str, *, callbacks: Callbacks = None, **kwargs: Any - ) -> List[Document]: - """Retrieve documents relevant to a query. - Args: - query: String to find relevant documents for. - callbacks: Callback manager or list of callbacks. 
- Returns: - List of relevant documents - """ - from langchain.callbacks.manager import CallbackManager - - callback_manager = CallbackManager.configure( - callbacks, None, verbose=kwargs.get("verbose", False) - ) - run_manager = callback_manager.on_retriever_start( - query, - **kwargs, - ) - try: - if self._new_arg_supported: - result = self._get_relevant_documents( - query, run_manager=run_manager, **kwargs - ) - elif self._expects_other_args: - result = self._get_relevant_documents(query, **kwargs) - else: - result = self._get_relevant_documents(query) # type: ignore[call-arg] - except Exception as e: - run_manager.on_retriever_error(e) - raise e - else: - run_manager.on_retriever_end( - result, - **kwargs, - ) - return result - - async def aget_relevant_documents( - self, query: str, *, callbacks: Callbacks = None, **kwargs: Any - ) -> List[Document]: - """Asynchronously get documents relevant to a query. - Args: - query: string to find relevant documents for - callbacks: Callback manager or list of callbacks - Returns: - List of relevant documents - """ - from langchain.callbacks.manager import AsyncCallbackManager - - callback_manager = AsyncCallbackManager.configure( - callbacks, None, verbose=kwargs.get("verbose", False) - ) - run_manager = await callback_manager.on_retriever_start( - query, - **kwargs, - ) - try: - if self._new_arg_supported: - result = await self._aget_relevant_documents( - query, run_manager=run_manager, **kwargs - ) - elif self._expects_other_args: - result = await self._aget_relevant_documents(query, **kwargs) - else: - result = await self._aget_relevant_documents( - query, # type: ignore[call-arg] - ) - except Exception as e: - await run_manager.on_retriever_error(e) - raise e - else: - await run_manager.on_retriever_end( - result, - **kwargs, - ) - return result - - -# For backwards compatibility -Memory = BaseMemory - -T = TypeVar("T") - - -class BaseLLMOutputParser(Serializable, ABC, Generic[T]): - """Abstract base class for parsing the outputs of a model.""" - - @abstractmethod - def parse_result(self, result: List[Generation]) -> T: - """Parse a list of candidate model Generations into a specific format. - - Args: - result: A list of Generations to be parsed. The Generations are assumed - to be different candidate outputs for a single model input. - - Returns: - Structured output. - """ - - -class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]): - """Class to parse the output of an LLM call. - - Output parsers help structure language model responses. - - Example: - .. code-block:: python - - class BooleanOutputParser(BaseOutputParser[bool]): - true_val: str = "YES" - false_val: str = "NO" - - def parse(self, text: str) -> bool: - cleaned_text = text.strip().upper() - if cleaned_text not in (self.true_val.upper(), self.false_val.upper()): - raise OutputParserException( - f"BooleanOutputParser expected output value to either be " - f"{self.true_val} or {self.false_val} (case-insensitive). " - f"Received {cleaned_text}." - ) - return cleaned_text == self.true_val.upper() - - @property - def _type(self) -> str: - return "boolean_output_parser" - """ # noqa: E501 - - def parse_result(self, result: List[Generation]) -> T: - """Parse a list of candidate model Generations into a specific format. - - The return value is parsed from only the first Generation in the result, which - is assumed to be the highest-likelihood Generation. - - Args: - result: A list of Generations to be parsed. 
The Generations are assumed - to be different candidate outputs for a single model input. - - Returns: - Structured output. - """ - return self.parse(result[0].text) - - @abstractmethod - def parse(self, text: str) -> T: - """Parse a single string model output into some structure. - - Args: - text: String output of language model. - - Returns: - Structured output. - """ - - # TODO: rename 'completion' -> 'text'. - def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: - """Parse the output of an LLM call with the input prompt for context. - - The prompt is largely provided in the event the OutputParser wants - to retry or fix the output in some way, and needs information from - the prompt to do so. - - Args: - completion: String output of language model. - prompt: Input PromptValue. - - Returns: - Structured output - """ - return self.parse(completion) - - def get_format_instructions(self) -> str: - """Instructions on how the LLM output should be formatted.""" - raise NotImplementedError - - @property - def _type(self) -> str: - """Return the output parser type for serialization.""" - raise NotImplementedError( - f"_type property is not implemented in class {self.__class__.__name__}." - " This is required for serialization." - ) - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of output parser.""" - output_parser_dict = super().dict(**kwargs) - output_parser_dict["_type"] = self._type - return output_parser_dict - - -class NoOpOutputParser(BaseOutputParser[str]): - """'No operation' OutputParser that returns the text as is.""" - - @property - def lc_serializable(self) -> bool: - """Whether the class LangChain serializable.""" - return True - - @property - def _type(self) -> str: - """Return the output parser type for serialization.""" - return "default" - - def parse(self, text: str) -> str: - """Returns the input text with no changes.""" - return text - - -class OutputParserException(ValueError): - """Exception that output parsers should raise to signify a parsing error. - - This exists to differentiate parsing errors from other code or execution errors - that also may arise inside the output parser. OutputParserExceptions will be - available to catch and handle in ways to fix the parsing error, while other - errors will be raised. - - Args: - error: The error that's being re-raised or an error message. - observation: String explanation of error which can be passed to a - model to try and remediate the issue. - llm_output: String model output which is error-ing. - send_to_llm: Whether to send the observation and llm_output back to an Agent - after an OutputParserException has been raised. This gives the underlying - model driving the agent the context that the previous output was improperly - structured, in the hopes that it will update the output to the correct - format. - """ - - def __init__( - self, - error: Any, - observation: Optional[str] = None, - llm_output: Optional[str] = None, - send_to_llm: bool = False, - ): - super(OutputParserException, self).__init__(error) - if send_to_llm: - if observation is None or llm_output is None: - raise ValueError( - "Arguments 'observation' & 'llm_output'" - " are required if 'send_to_llm' is True" - ) - self.observation = observation - self.llm_output = llm_output - self.send_to_llm = send_to_llm - - -class BaseDocumentTransformer(ABC): - """Abstract base class for document transformation systems. 
- - A document transformation system takes a sequence of Documents and returns a - sequence of transformed Documents. - - Example: - .. code-block:: python - - class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): - embeddings: Embeddings - similarity_fn: Callable = cosine_similarity - similarity_threshold: float = 0.95 - - class Config: - arbitrary_types_allowed = True - - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - stateful_documents = get_stateful_documents(documents) - embedded_documents = _get_embeddings_from_stateful_docs( - self.embeddings, stateful_documents - ) - included_idxs = _filter_similar_embeddings( - embedded_documents, self.similarity_fn, self.similarity_threshold - ) - return [stateful_documents[i] for i in sorted(included_idxs)] - - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - raise NotImplementedError - - """ # noqa: E501 - - @abstractmethod - def transform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - """Transform a list of documents. - - Args: - documents: A sequence of Documents to be transformed. - - Returns: - A list of transformed Documents. - """ - - @abstractmethod - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - """Asynchronously transform a list of documents. - - Args: - documents: A sequence of Documents to be transformed. - - Returns: - A list of transformed Documents. - """ diff --git a/langchain/schema/__init__.py b/langchain/schema/__init__.py new file mode 100644 index 0000000000..0821f2fb3f --- /dev/null +++ b/langchain/schema/__init__.py @@ -0,0 +1,67 @@ +from langchain.schema.agent import AgentAction, AgentFinish +from langchain.schema.document import BaseDocumentTransformer, Document +from langchain.schema.memory import BaseChatMessageHistory, BaseMemory +from langchain.schema.messages import ( + AIMessage, + BaseMessage, + ChatMessage, + FunctionMessage, + HumanMessage, + SystemMessage, + _message_from_dict, + _message_to_dict, + get_buffer_string, + messages_from_dict, + messages_to_dict, +) +from langchain.schema.output import ( + ChatGeneration, + ChatResult, + Generation, + LLMResult, + RunInfo, +) +from langchain.schema.output_parser import ( + BaseLLMOutputParser, + BaseOutputParser, + NoOpOutputParser, + OutputParserException, +) +from langchain.schema.prompt import PromptValue +from langchain.schema.retriever import BaseRetriever + +RUN_KEY = "__run" +Memory = BaseMemory + +__all__ = [ + "BaseMemory", + "BaseChatMessageHistory", + "AgentFinish", + "AgentAction", + "Document", + "BaseDocumentTransformer", + "BaseMessage", + "ChatMessage", + "FunctionMessage", + "HumanMessage", + "AIMessage", + "SystemMessage", + "messages_from_dict", + "messages_to_dict", + "_message_to_dict", + "_message_from_dict", + "get_buffer_string", + "RunInfo", + "LLMResult", + "ChatResult", + "ChatGeneration", + "Generation", + "PromptValue", + "BaseRetriever", + "RUN_KEY", + "Memory", + "OutputParserException", + "NoOpOutputParser", + "BaseOutputParser", + "BaseLLMOutputParser", +] diff --git a/langchain/schema/agent.py b/langchain/schema/agent.py new file mode 100644 index 0000000000..5d58993cb5 --- /dev/null +++ b/langchain/schema/agent.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import NamedTuple, Union + + +@dataclass +class AgentAction: + """A 
full description of an action for an ActionAgent to execute.""" + + tool: str + """The name of the Tool to execute.""" + tool_input: Union[str, dict] + """The input to pass in to the Tool.""" + log: str + """Additional information to log about the action.""" + + +class AgentFinish(NamedTuple): + """The final return value of an ActionAgent.""" + + return_values: dict + """Dictionary of return values.""" + log: str + """Additional information to log about the return value""" diff --git a/langchain/schema/document.py b/langchain/schema/document.py new file mode 100644 index 0000000000..a05df15db6 --- /dev/null +++ b/langchain/schema/document.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Sequence + +from pydantic import Field + +from langchain.load.serializable import Serializable + + +class Document(Serializable): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + """String text.""" + metadata: dict = Field(default_factory=dict) + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + + +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + @abstractmethod + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ diff --git a/langchain/schema/memory.py b/langchain/schema/memory.py new file mode 100644 index 0000000000..14a07bb379 --- /dev/null +++ b/langchain/schema/memory.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +from langchain.load.serializable import Serializable +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage + + +class BaseMemory(Serializable, ABC): + """Base abstract class for memory in Chains. + + Memory refers to state in Chains. 
Memory can be used to store information about
+    past executions of a Chain and inject that information into the inputs of
+    future executions of the Chain. For example, for conversational Chains, Memory
+    can be used to store conversations and automatically add them to future model
+    prompts so that the model has the necessary context to respond coherently to
+    the latest input.
+
+    Example:
+        .. code-block:: python
+
+            class SimpleMemory(BaseMemory):
+                memories: Dict[str, Any] = dict()
+
+                @property
+                def memory_variables(self) -> List[str]:
+                    return list(self.memories.keys())
+
+                def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+                    return self.memories
+
+                def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+                    pass
+
+                def clear(self) -> None:
+                    pass
+    """  # noqa: E501
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    @property
+    @abstractmethod
+    def memory_variables(self) -> List[str]:
+        """The string keys this memory class will add to chain inputs."""
+
+    @abstractmethod
+    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Return key-value pairs given the text input to the chain."""
+
+    @abstractmethod
+    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+        """Save the context of this chain run to memory."""
+
+    @abstractmethod
+    def clear(self) -> None:
+        """Clear memory contents."""
+
+
+class BaseChatMessageHistory(ABC):
+    """Abstract base class for storing chat message history.
+
+    See `ChatMessageHistory` for default implementation.
+
+    Example:
+        .. code-block:: python
+
+            class FileChatMessageHistory(BaseChatMessageHistory):
+                storage_path: str
+                session_id: str
+
+                @property
+                def messages(self):
+                    path = os.path.join(self.storage_path, self.session_id)
+                    with open(path, encoding="utf-8") as f:
+                        messages = json.loads(f.read())
+                    return messages_from_dict(messages)
+
+                def add_message(self, message: BaseMessage) -> None:
+                    messages = messages_to_dict(self.messages)
+                    messages.append(_message_to_dict(message))
+                    path = os.path.join(self.storage_path, self.session_id)
+                    with open(path, "w") as f:
+                        json.dump(messages, f)
+
+                def clear(self):
+                    path = os.path.join(self.storage_path, self.session_id)
+                    with open(path, "w") as f:
+                        f.write("[]")
+    """
+
+    messages: List[BaseMessage]
+    """A list of Messages stored in-memory."""
+
+    def add_user_message(self, message: str) -> None:
+        """Convenience method for adding a human message string to the store.
+
+        Args:
+            message: The string contents of a human message.
+        """
+        self.add_message(HumanMessage(content=message))
+
+    def add_ai_message(self, message: str) -> None:
+        """Convenience method for adding an AI message string to the store.
+
+        Args:
+            message: The string contents of an AI message.
+        """
+        self.add_message(AIMessage(content=message))
+
+    # TODO: Make this an abstractmethod.
+    def add_message(self, message: BaseMessage) -> None:
+        """Add a Message object to the store.
+
+        Args:
+            message: A BaseMessage object to store.
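+
+        Example:
+            A minimal usage sketch (illustrative only; assumes the in-memory
+            ``ChatMessageHistory`` implementation from ``langchain.memory``):
+
+            .. code-block:: python
+
+                from langchain.memory import ChatMessageHistory
+                from langchain.schema.messages import HumanMessage
+
+                history = ChatMessageHistory()
+                history.add_message(HumanMessage(content="hello"))
+                history.add_ai_message("hi there")
+                assert len(history.messages) == 2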
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def clear(self) -> None:
+        """Remove all messages from the store."""
diff --git a/langchain/schema/messages.py b/langchain/schema/messages.py
new file mode 100644
index 0000000000..c03ae20358
--- /dev/null
+++ b/langchain/schema/messages.py
@@ -0,0 +1,183 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from typing import List, Sequence
+
+from pydantic import Field
+
+from langchain.load.serializable import Serializable
+
+
+def get_buffer_string(
+    messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
+) -> str:
+    """Convert a sequence of Messages to strings and concatenate them into one string.
+
+    Args:
+        messages: Messages to be converted to strings.
+        human_prefix: The prefix to prepend to contents of HumanMessages.
+        ai_prefix: The prefix to prepend to contents of AIMessages.
+
+    Returns:
+        A single string concatenation of all input messages.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.schema import AIMessage, HumanMessage
+
+            messages = [
+                HumanMessage(content="Hi, how are you?"),
+                AIMessage(content="Good, how are you?"),
+            ]
+            get_buffer_string(messages)
+            # -> "Human: Hi, how are you?\nAI: Good, how are you?"
+    """
+    string_messages = []
+    for m in messages:
+        if isinstance(m, HumanMessage):
+            role = human_prefix
+        elif isinstance(m, AIMessage):
+            role = ai_prefix
+        elif isinstance(m, SystemMessage):
+            role = "System"
+        elif isinstance(m, FunctionMessage):
+            role = "Function"
+        elif isinstance(m, ChatMessage):
+            role = m.role
+        else:
+            raise ValueError(f"Got unsupported message type: {m}")
+        message = f"{role}: {m.content}"
+        if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
+            message += f"{m.additional_kwargs['function_call']}"
+        string_messages.append(message)
+
+    return "\n".join(string_messages)
+
+
+class BaseMessage(Serializable):
+    """The base abstract Message class.
+
+    Messages are the inputs and outputs of ChatModels.
+    """
+
+    content: str
+    """The string contents of the message."""
+
+    additional_kwargs: dict = Field(default_factory=dict)
+    """Any additional information."""
+
+    @property
+    @abstractmethod
+    def type(self) -> str:
+        """Type of the Message, used for serialization."""
+
+    @property
+    def lc_serializable(self) -> bool:
+        """Whether this class is LangChain serializable."""
+        return True
+
+
+class HumanMessage(BaseMessage):
+    """A Message from a human."""
+
+    example: bool = False
+    """Whether this Message is being passed in to the model as part of an example
+    conversation.
+    """
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "human"
+
+
+class AIMessage(BaseMessage):
+    """A Message from an AI."""
+
+    example: bool = False
+    """Whether this Message is being passed in to the model as part of an example
+    conversation.
+    """
+
+    @property
+    def type(self) -> str:
+        """Type of the message, used for serialization."""
+        return "ai"
+
+
+class SystemMessage(BaseMessage):
+    """A Message for priming AI behavior, usually passed in as the first of a sequence
+    of input messages.
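+
+    Example:
+        An illustrative sketch of a typical chat prompt (the model call itself
+        is omitted):
+
+        .. code-block:: python
+
+            messages = [
+                SystemMessage(content="You are a helpful assistant."),
+                HumanMessage(content="What is the capital of France?"),
+            ]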
+ """ + + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "system" + + +class FunctionMessage(BaseMessage): + """A Message for passing the result of executing a function back to a model.""" + + name: str + """The name of the function that was executed.""" + + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "function" + + +class ChatMessage(BaseMessage): + """A Message that can be assigned an arbitrary speaker (i.e. role).""" + + role: str + """The speaker / role of the Message.""" + + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "chat" + + +def _message_to_dict(message: BaseMessage) -> dict: + return {"type": message.type, "data": message.dict()} + + +def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]: + """Convert a sequence of Messages to a list of dictionaries. + + Args: + messages: Sequence of messages (as BaseMessages) to convert. + + Returns: + List of messages as dicts. + """ + return [_message_to_dict(m) for m in messages] + + +def _message_from_dict(message: dict) -> BaseMessage: + _type = message["type"] + if _type == "human": + return HumanMessage(**message["data"]) + elif _type == "ai": + return AIMessage(**message["data"]) + elif _type == "system": + return SystemMessage(**message["data"]) + elif _type == "chat": + return ChatMessage(**message["data"]) + else: + raise ValueError(f"Got unexpected type: {_type}") + + +def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: + """Convert a sequence of messages from dicts to Message objects. + + Args: + messages: Sequence of messages (as dicts) to convert. + + Returns: + List of messages (BaseMessages). + """ + return [_message_from_dict(m) for m in messages] diff --git a/langchain/schema/output.py b/langchain/schema/output.py new file mode 100644 index 0000000000..c085a49524 --- /dev/null +++ b/langchain/schema/output.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Dict, List, Optional +from uuid import UUID + +from pydantic import BaseModel, root_validator + +from langchain.load.serializable import Serializable +from langchain.schema.messages import BaseMessage + + +class Generation(Serializable): + """A single text generation output.""" + + text: str + """Generated text output.""" + + generation_info: Optional[Dict[str, Any]] = None + """Raw response from the provider. May include things like the + reason for finishing or token log probabilities. 
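+
+    For example (an illustrative shape only; the exact keys are
+    provider-specific and not guaranteed):
+
+    .. code-block:: python
+
+        Generation(
+            text="The answer is 42.",
+            generation_info={"finish_reason": "stop", "logprobs": None},
+        )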
+ """ + # TODO: add log probs as separate attribute + + @property + def lc_serializable(self) -> bool: + """Whether this class is LangChain serializable.""" + return True + + +class ChatGeneration(Generation): + """A single chat generation output.""" + + text: str = "" + """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message.""" + message: BaseMessage + """The message output by the chat model.""" + + @root_validator + def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Set the text attribute to be the contents of the message.""" + values["text"] = values["message"].content + return values + + +class RunInfo(BaseModel): + """Class that contains metadata for a single execution of a Chain or model.""" + + run_id: UUID + """A unique identifier for the model or chain run.""" + + +class ChatResult(BaseModel): + """Class that contains all results for a single chat model call.""" + + generations: List[ChatGeneration] + """List of the chat generations. This is a List because an input can have multiple + candidate generations. + """ + llm_output: Optional[dict] = None + """For arbitrary LLM provider specific output.""" + + +class LLMResult(BaseModel): + """Class that contains all results for a batched LLM call.""" + + generations: List[List[Generation]] + """List of generated outputs. This is a List[List[]] because + each input could have multiple candidate generations.""" + llm_output: Optional[dict] = None + """Arbitrary LLM provider-specific output.""" + run: Optional[List[RunInfo]] = None + """List of metadata info for model call for each input.""" + + def flatten(self) -> List[LLMResult]: + """Flatten generations into a single list. + + Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult + contains only a single Generation. If token usage information is available, + it is kept only for the LLMResult corresponding to the top-choice + Generation, to avoid over-counting of token usage downstream. + + Returns: + List of LLMResults where each returned LLMResult contains a single + Generation. 
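+
+        Example:
+            An illustrative sketch of the unpacking behavior:
+
+            .. code-block:: python
+
+                result = LLMResult(
+                    generations=[[Generation(text="foo")], [Generation(text="bar")]],
+                    llm_output={"token_usage": {"total_tokens": 4}},
+                )
+                flat = result.flatten()
+                assert len(flat) == 2
+                # Token usage is kept only on the first flattened result.
+                assert flat[1].llm_output == {"token_usage": {}}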
+ """ + llm_results = [] + for i, gen_list in enumerate(self.generations): + # Avoid double counting tokens in OpenAICallback + if i == 0: + llm_results.append( + LLMResult( + generations=[gen_list], + llm_output=self.llm_output, + ) + ) + else: + if self.llm_output is not None: + llm_output = deepcopy(self.llm_output) + llm_output["token_usage"] = dict() + else: + llm_output = None + llm_results.append( + LLMResult( + generations=[gen_list], + llm_output=llm_output, + ) + ) + return llm_results + + def __eq__(self, other: object) -> bool: + """Check for LLMResult equality by ignoring any metadata related to runs.""" + if not isinstance(other, LLMResult): + return NotImplemented + return ( + self.generations == other.generations + and self.llm_output == other.llm_output + ) diff --git a/langchain/schema/output_parser.py b/langchain/schema/output_parser.py new file mode 100644 index 0000000000..ee2ad9736a --- /dev/null +++ b/langchain/schema/output_parser.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, List, Optional, TypeVar + +from langchain.load.serializable import Serializable +from langchain.schema.output import Generation +from langchain.schema.prompt import PromptValue + +T = TypeVar("T") + + +class BaseLLMOutputParser(Serializable, ABC, Generic[T]): + """Abstract base class for parsing the outputs of a model.""" + + @abstractmethod + def parse_result(self, result: List[Generation]) -> T: + """Parse a list of candidate model Generations into a specific format. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. + """ + + +class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]): + """Class to parse the output of an LLM call. + + Output parsers help structure language model responses. + + Example: + .. code-block:: python + + class BooleanOutputParser(BaseOutputParser[bool]): + true_val: str = "YES" + false_val: str = "NO" + + def parse(self, text: str) -> bool: + cleaned_text = text.strip().upper() + if cleaned_text not in (self.true_val.upper(), self.false_val.upper()): + raise OutputParserException( + f"BooleanOutputParser expected output value to either be " + f"{self.true_val} or {self.false_val} (case-insensitive). " + f"Received {cleaned_text}." + ) + return cleaned_text == self.true_val.upper() + + @property + def _type(self) -> str: + return "boolean_output_parser" + """ # noqa: E501 + + def parse_result(self, result: List[Generation]) -> T: + """Parse a list of candidate model Generations into a specific format. + + The return value is parsed from only the first Generation in the result, which + is assumed to be the highest-likelihood Generation. + + Args: + result: A list of Generations to be parsed. The Generations are assumed + to be different candidate outputs for a single model input. + + Returns: + Structured output. + """ + return self.parse(result[0].text) + + @abstractmethod + def parse(self, text: str) -> T: + """Parse a single string model output into some structure. + + Args: + text: String output of language model. + + Returns: + Structured output. + """ + + # TODO: rename 'completion' -> 'text'. + def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: + """Parse the output of an LLM call with the input prompt for context. 
+
+        The prompt is largely provided in the event the OutputParser wants
+        to retry or fix the output in some way, and needs information from
+        the prompt to do so.
+
+        Args:
+            completion: String output of language model.
+            prompt: Input PromptValue.
+
+        Returns:
+            Structured output.
+        """
+        return self.parse(completion)
+
+    def get_format_instructions(self) -> str:
+        """Instructions on how the LLM output should be formatted."""
+        raise NotImplementedError
+
+    @property
+    def _type(self) -> str:
+        """Return the output parser type for serialization."""
+        raise NotImplementedError(
+            f"_type property is not implemented in class {self.__class__.__name__}."
+            " This is required for serialization."
+        )
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of output parser."""
+        output_parser_dict = super().dict(**kwargs)
+        output_parser_dict["_type"] = self._type
+        return output_parser_dict
+
+
+class NoOpOutputParser(BaseOutputParser[str]):
+    """'No operation' OutputParser that returns the text as is."""
+
+    @property
+    def lc_serializable(self) -> bool:
+        """Whether the class is LangChain serializable."""
+        return True
+
+    @property
+    def _type(self) -> str:
+        """Return the output parser type for serialization."""
+        return "default"
+
+    def parse(self, text: str) -> str:
+        """Returns the input text with no changes."""
+        return text
+
+
+class OutputParserException(ValueError):
+    """Exception that output parsers should raise to signify a parsing error.
+
+    This exists to differentiate parsing errors from other code or execution errors
+    that also may arise inside the output parser. OutputParserExceptions can be
+    caught and handled in ways that fix the parsing error, while other
+    errors will be raised.
+
+    Args:
+        error: The error that's being re-raised or an error message.
+        observation: String explanation of the error which can be passed to a
+            model to try and remediate the issue.
+        llm_output: String model output which is causing the error.
+        send_to_llm: Whether to send the observation and llm_output back to an Agent
+            after an OutputParserException has been raised. This gives the underlying
+            model driving the agent the context that the previous output was improperly
+            structured, in the hopes that it will update the output to the correct
+            format.
+    """
+
+    def __init__(
+        self,
+        error: Any,
+        observation: Optional[str] = None,
+        llm_output: Optional[str] = None,
+        send_to_llm: bool = False,
+    ):
+        super().__init__(error)
+        if send_to_llm:
+            if observation is None or llm_output is None:
+                raise ValueError(
+                    "Arguments 'observation' & 'llm_output'"
+                    " are required if 'send_to_llm' is True"
+                )
+        self.observation = observation
+        self.llm_output = llm_output
+        self.send_to_llm = send_to_llm
diff --git a/langchain/schema/prompt.py b/langchain/schema/prompt.py
new file mode 100644
index 0000000000..d273af0dbf
--- /dev/null
+++ b/langchain/schema/prompt.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import List
+
+from langchain.load.serializable import Serializable
+from langchain.schema.messages import BaseMessage
+
+
+class PromptValue(Serializable, ABC):
+    """Base abstract class for inputs to any language model.
+
+    PromptValues can be converted to both LLM (pure text-generation) inputs and
+    ChatModel inputs.
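+
+    Example:
+        A minimal sketch (illustrative only; ``PromptTemplate.format_prompt``
+        returns a concrete ``PromptValue``):
+
+        .. code-block:: python
+
+            from langchain.prompts import PromptTemplate
+
+            prompt = PromptTemplate.from_template("Tell me a joke about {topic}")
+            value = prompt.format_prompt(topic="bears")
+            value.to_string()
+            # -> "Tell me a joke about bears"
+            value.to_messages()
+            # -> [HumanMessage(content="Tell me a joke about bears")]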
+ """ + + @abstractmethod + def to_string(self) -> str: + """Return prompt value as string.""" + + @abstractmethod + def to_messages(self) -> List[BaseMessage]: + """Return prompt as a list of Messages.""" diff --git a/langchain/schema/retriever.py b/langchain/schema/retriever.py new file mode 100644 index 0000000000..d3e59fb75c --- /dev/null +++ b/langchain/schema/retriever.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import warnings +from abc import ABC, abstractmethod +from inspect import signature +from typing import TYPE_CHECKING, Any, List + +from langchain.schema.document import Document + +if TYPE_CHECKING: + from langchain.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, + Callbacks, + ) + + +class BaseRetriever(ABC): + """Abstract base class for a Document retrieval system. + + A retrieval system is defined as something that can take string queries and return + the most 'relevant' Documents from some source. + + Example: + .. code-block:: python + + class TFIDFRetriever(BaseRetriever, BaseModel): + vectorizer: Any + docs: List[Document] + tfidf_array: Any + k: int = 4 + + class Config: + arbitrary_types_allowed = True + + def get_relevant_documents(self, query: str) -> List[Document]: + from sklearn.metrics.pairwise import cosine_similarity + + # Ip -- (n_docs,x), Op -- (n_docs,n_Feats) + query_vec = self.vectorizer.transform([query]) + # Op -- (n_docs,1) -- Cosine Sim with each doc + results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,)) + return [self.docs[i] for i in results.argsort()[-self.k :][::-1]] + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError + + """ # noqa: E501 + + _new_arg_supported: bool = False + _expects_other_args: bool = False + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + # Version upgrade for old retrievers that implemented the public + # methods directly. + if cls.get_relevant_documents != BaseRetriever.get_relevant_documents: + warnings.warn( + "Retrievers must implement abstract `_get_relevant_documents` method" + " instead of `get_relevant_documents`", + DeprecationWarning, + ) + swap = cls.get_relevant_documents + cls.get_relevant_documents = ( # type: ignore[assignment] + BaseRetriever.get_relevant_documents + ) + cls._get_relevant_documents = swap # type: ignore[assignment] + if ( + hasattr(cls, "aget_relevant_documents") + and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents + ): + warnings.warn( + "Retrievers must implement abstract `_aget_relevant_documents` method" + " instead of `aget_relevant_documents`", + DeprecationWarning, + ) + aswap = cls.aget_relevant_documents + cls.aget_relevant_documents = ( # type: ignore[assignment] + BaseRetriever.aget_relevant_documents + ) + cls._aget_relevant_documents = aswap # type: ignore[assignment] + parameters = signature(cls._get_relevant_documents).parameters + cls._new_arg_supported = parameters.get("run_manager") is not None + # If a V1 retriever broke the interface and expects additional arguments + cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2 + + @abstractmethod + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any + ) -> List[Document]: + """Get documents relevant to a query. + Args: + query: String to find relevant documents for. + run_manager: The callbacks handler to use. 
+ Returns: + List of relevant documents + """ + + @abstractmethod + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + **kwargs: Any, + ) -> List[Document]: + """Asynchronously get documents relevant to a query. + Args: + query: string to find relevant documents for + run_manager: The callbacks handler to use + Returns: + List of relevant documents + """ + + def get_relevant_documents( + self, query: str, *, callbacks: Callbacks = None, **kwargs: Any + ) -> List[Document]: + """Retrieve documents relevant to a query. + Args: + query: String to find relevant documents for. + callbacks: Callback manager or list of callbacks. + Returns: + List of relevant documents + """ + from langchain.callbacks.manager import CallbackManager + + callback_manager = CallbackManager.configure( + callbacks, None, verbose=kwargs.get("verbose", False) + ) + run_manager = callback_manager.on_retriever_start( + query, + **kwargs, + ) + try: + if self._new_arg_supported: + result = self._get_relevant_documents( + query, run_manager=run_manager, **kwargs + ) + elif self._expects_other_args: + result = self._get_relevant_documents(query, **kwargs) + else: + result = self._get_relevant_documents(query) # type: ignore[call-arg] + except Exception as e: + run_manager.on_retriever_error(e) + raise e + else: + run_manager.on_retriever_end( + result, + **kwargs, + ) + return result + + async def aget_relevant_documents( + self, query: str, *, callbacks: Callbacks = None, **kwargs: Any + ) -> List[Document]: + """Asynchronously get documents relevant to a query. + Args: + query: string to find relevant documents for + callbacks: Callback manager or list of callbacks + Returns: + List of relevant documents + """ + from langchain.callbacks.manager import AsyncCallbackManager + + callback_manager = AsyncCallbackManager.configure( + callbacks, None, verbose=kwargs.get("verbose", False) + ) + run_manager = await callback_manager.on_retriever_start( + query, + **kwargs, + ) + try: + if self._new_arg_supported: + result = await self._aget_relevant_documents( + query, run_manager=run_manager, **kwargs + ) + elif self._expects_other_args: + result = await self._aget_relevant_documents(query, **kwargs) + else: + result = await self._aget_relevant_documents( + query, # type: ignore[call-arg] + ) + except Exception as e: + await run_manager.on_retriever_error(e) + raise e + else: + await run_manager.on_retriever_end( + result, + **kwargs, + ) + return result diff --git a/tests/integration_tests/chat_models/test_anthropic.py b/tests/integration_tests/chat_models/test_anthropic.py index f0c2de7980..6704bde096 100644 --- a/tests/integration_tests/chat_models/test_anthropic.py +++ b/tests/integration_tests/chat_models/test_anthropic.py @@ -6,12 +6,10 @@ import pytest from langchain.callbacks.manager import CallbackManager from langchain.chat_models.anthropic import ChatAnthropic from langchain.schema import ( - AIMessage, - BaseMessage, ChatGeneration, - HumanMessage, LLMResult, ) +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/integration_tests/chat_models/test_google_palm.py b/tests/integration_tests/chat_models/test_google_palm.py index a95419e60a..09ded60064 100644 --- a/tests/integration_tests/chat_models/test_google_palm.py +++ b/tests/integration_tests/chat_models/test_google_palm.py @@ -8,13 +8,11 @@ import pytest from langchain.chat_models 
import ChatGooglePalm from langchain.schema import ( - BaseMessage, ChatGeneration, ChatResult, - HumanMessage, LLMResult, - SystemMessage, ) +from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage def test_chat_google_palm() -> None: diff --git a/tests/integration_tests/chat_models/test_openai.py b/tests/integration_tests/chat_models/test_openai.py index 568a92222e..af3157ed6d 100644 --- a/tests/integration_tests/chat_models/test_openai.py +++ b/tests/integration_tests/chat_models/test_openai.py @@ -6,13 +6,11 @@ import pytest from langchain.callbacks.manager import CallbackManager from langchain.chat_models.openai import ChatOpenAI from langchain.schema import ( - BaseMessage, ChatGeneration, ChatResult, - HumanMessage, LLMResult, - SystemMessage, ) +from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/integration_tests/chat_models/test_promptlayer_openai.py b/tests/integration_tests/chat_models/test_promptlayer_openai.py index ab68a0850b..3e5e9f8850 100644 --- a/tests/integration_tests/chat_models/test_promptlayer_openai.py +++ b/tests/integration_tests/chat_models/test_promptlayer_openai.py @@ -5,13 +5,11 @@ import pytest from langchain.callbacks.manager import CallbackManager from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.schema import ( - BaseMessage, ChatGeneration, ChatResult, - HumanMessage, LLMResult, - SystemMessage, ) +from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/integration_tests/chat_models/test_vertexai.py b/tests/integration_tests/chat_models/test_vertexai.py index cb69d68ff8..d4d9ed9778 100644 --- a/tests/integration_tests/chat_models/test_vertexai.py +++ b/tests/integration_tests/chat_models/test_vertexai.py @@ -13,11 +13,7 @@ import pytest from langchain.chat_models import ChatVertexAI from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history -from langchain.schema import ( - AIMessage, - HumanMessage, - SystemMessage, -) +from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage def test_vertexai_single_call() -> None: diff --git a/tests/integration_tests/memory/test_cassandra.py b/tests/integration_tests/memory/test_cassandra.py index 71c45c7a86..3e6572f58b 100644 --- a/tests/integration_tests/memory/test_cassandra.py +++ b/tests/integration_tests/memory/test_cassandra.py @@ -8,10 +8,7 @@ from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories.cassandra import ( CassandraChatMessageHistory, ) -from langchain.schema import ( - AIMessage, - HumanMessage, -) +from langchain.schema.messages import AIMessage, HumanMessage def _chat_message_history( diff --git a/tests/integration_tests/memory/test_cosmos_db.py b/tests/integration_tests/memory/test_cosmos_db.py index fd0cd99f6b..1792692bc1 100644 --- a/tests/integration_tests/memory/test_cosmos_db.py +++ b/tests/integration_tests/memory/test_cosmos_db.py @@ -3,7 +3,7 @@ import os from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory -from langchain.schema import _message_to_dict +from langchain.schema.messages import _message_to_dict # Replace these with your Azure Cosmos DB endpoint and key endpoint = os.environ["COSMOS_DB_ENDPOINT"] diff --git 
a/tests/integration_tests/memory/test_firestore.py b/tests/integration_tests/memory/test_firestore.py index 0391b39ef8..4dcff90f50 100644 --- a/tests/integration_tests/memory/test_firestore.py +++ b/tests/integration_tests/memory/test_firestore.py @@ -2,7 +2,7 @@ import json from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import FirestoreChatMessageHistory -from langchain.schema import _message_to_dict +from langchain.schema.messages import _message_to_dict def test_memory_with_message_store() -> None: diff --git a/tests/integration_tests/memory/test_momento.py b/tests/integration_tests/memory/test_momento.py index 0260f6dba3..99f2327a7a 100644 --- a/tests/integration_tests/memory/test_momento.py +++ b/tests/integration_tests/memory/test_momento.py @@ -14,7 +14,7 @@ from momento import CacheClient, Configurations, CredentialProvider from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import MomentoChatMessageHistory -from langchain.schema import _message_to_dict +from langchain.schema.messages import _message_to_dict def random_string() -> str: diff --git a/tests/integration_tests/memory/test_mongodb.py b/tests/integration_tests/memory/test_mongodb.py index 9e1b0f0060..a57ccd0522 100644 --- a/tests/integration_tests/memory/test_mongodb.py +++ b/tests/integration_tests/memory/test_mongodb.py @@ -3,7 +3,7 @@ import os from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import MongoDBChatMessageHistory -from langchain.schema import _message_to_dict +from langchain.schema.messages import _message_to_dict # Replace these with your mongodb connection string connection_string = os.environ["MONGODB_CONNECTION_STRING"] diff --git a/tests/integration_tests/memory/test_redis.py b/tests/integration_tests/memory/test_redis.py index 16ea653c30..547f6ab3c2 100644 --- a/tests/integration_tests/memory/test_redis.py +++ b/tests/integration_tests/memory/test_redis.py @@ -2,7 +2,7 @@ import json from langchain.memory import ConversationBufferMemory from langchain.memory.chat_message_histories import RedisChatMessageHistory -from langchain.schema import _message_to_dict +from langchain.schema.messages import _message_to_dict def test_memory_with_message_store() -> None: diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index f607b3c708..dc7a8e777b 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -6,7 +6,7 @@ from uuid import UUID from pydantic import BaseModel from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler -from langchain.schema import BaseMessage +from langchain.schema.messages import BaseMessage class BaseFakeCallbackHandler(BaseModel): diff --git a/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/tests/unit_tests/callbacks/tracers/test_base_tracer.py index ca7735467d..c73253a2b9 100644 --- a/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -11,7 +11,8 @@ from freezegun import freeze_time from langchain.callbacks.manager import CallbackManager from langchain.callbacks.tracers.base import BaseTracer, TracerException from langchain.callbacks.tracers.schemas import Run -from langchain.schema import HumanMessage, LLMResult +from langchain.schema import LLMResult +from langchain.schema.messages import HumanMessage 
SERIALIZED = {"id": ["llm"]} SERIALIZED_CHAT = {"id": ["chat_model"]} diff --git a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py index f57b93a074..a7fab61127 100644 --- a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py +++ b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py @@ -18,7 +18,8 @@ from langchain.callbacks.tracers.langchain_v1 import ( TracerSessionV1, ) from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base -from langchain.schema import HumanMessage, LLMResult +from langchain.schema import LLMResult +from langchain.schema.messages import HumanMessage TEST_SESSION_ID = 2023 diff --git a/tests/unit_tests/chat_models/test_google_palm.py b/tests/unit_tests/chat_models/test_google_palm.py index 0ca7fb4eb2..8bcb9ee78f 100644 --- a/tests/unit_tests/chat_models/test_google_palm.py +++ b/tests/unit_tests/chat_models/test_google_palm.py @@ -7,11 +7,7 @@ from langchain.chat_models.google_palm import ( ChatGooglePalmError, _messages_to_prompt_dict, ) -from langchain.schema import ( - AIMessage, - HumanMessage, - SystemMessage, -) +from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage def test_messages_to_prompt_dict_with_valid_messages() -> None: diff --git a/tests/unit_tests/chat_models/test_openai.py b/tests/unit_tests/chat_models/test_openai.py index 9720eb9888..ad05133f18 100644 --- a/tests/unit_tests/chat_models/test_openai.py +++ b/tests/unit_tests/chat_models/test_openai.py @@ -5,9 +5,7 @@ import json from langchain.chat_models.openai import ( _convert_dict_to_message, ) -from langchain.schema import ( - FunctionMessage, -) +from langchain.schema.messages import FunctionMessage def test_function_message_dict_to_function_message() -> None: diff --git a/tests/unit_tests/llms/fake_chat_model.py b/tests/unit_tests/llms/fake_chat_model.py index f68a7532d2..5b8218720d 100644 --- a/tests/unit_tests/llms/fake_chat_model.py +++ b/tests/unit_tests/llms/fake_chat_model.py @@ -6,7 +6,8 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) from langchain.chat_models.base import SimpleChatModel -from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult +from langchain.schema import ChatGeneration, ChatResult +from langchain.schema.messages import AIMessage, BaseMessage class FakeChatModel(SimpleChatModel): diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index 72babf3a2a..30cc19cd3f 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -1,7 +1,7 @@ """Test LLM callbacks.""" from langchain.chat_models.fake import FakeListChatModel from langchain.llms.fake import FakeListLLM -from langchain.schema import HumanMessage +from langchain.schema.messages import HumanMessage from tests.unit_tests.callbacks.fake_callback_handler import ( FakeCallbackHandler, FakeCallbackHandlerWithChatStart, diff --git a/tests/unit_tests/memory/chat_message_histories/test_file.py b/tests/unit_tests/memory/chat_message_histories/test_file.py index 13962370e5..a2351671c4 100644 --- a/tests/unit_tests/memory/chat_message_histories/test_file.py +++ b/tests/unit_tests/memory/chat_message_histories/test_file.py @@ -5,7 +5,7 @@ from typing import Generator import pytest from langchain.memory.chat_message_histories import FileChatMessageHistory -from langchain.schema import AIMessage, HumanMessage +from langchain.schema.messages import AIMessage, HumanMessage 
@pytest.fixture diff --git a/tests/unit_tests/memory/chat_message_histories/test_sql.py b/tests/unit_tests/memory/chat_message_histories/test_sql.py index 0299ad0ac7..42cff47b04 100644 --- a/tests/unit_tests/memory/chat_message_histories/test_sql.py +++ b/tests/unit_tests/memory/chat_message_histories/test_sql.py @@ -4,7 +4,7 @@ from typing import Tuple import pytest from langchain.memory.chat_message_histories import SQLChatMessageHistory -from langchain.schema import AIMessage, HumanMessage +from langchain.schema.messages import AIMessage, HumanMessage # @pytest.fixture(params=[("SQLite"), ("postgresql")]) diff --git a/tests/unit_tests/memory/chat_message_histories/test_zep.py b/tests/unit_tests/memory/chat_message_histories/test_zep.py index 8dd1b4ace0..78967657ed 100644 --- a/tests/unit_tests/memory/chat_message_histories/test_zep.py +++ b/tests/unit_tests/memory/chat_message_histories/test_zep.py @@ -4,7 +4,7 @@ import pytest from pytest_mock import MockerFixture from langchain.memory.chat_message_histories import ZepChatMessageHistory -from langchain.schema import AIMessage, HumanMessage +from langchain.schema.messages import AIMessage, HumanMessage if TYPE_CHECKING: from zep_python import ZepClient diff --git a/tests/unit_tests/prompts/test_chat.py b/tests/unit_tests/prompts/test_chat.py index 17c114e640..d355b90dfc 100644 --- a/tests/unit_tests/prompts/test_chat.py +++ b/tests/unit_tests/prompts/test_chat.py @@ -13,7 +13,7 @@ from langchain.prompts.chat import ( HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.schema import HumanMessage +from langchain.schema.messages import HumanMessage def create_messages() -> List[BaseMessagePromptTemplate]: diff --git a/tests/unit_tests/test_cache.py b/tests/unit_tests/test_cache.py index 0cbe324e1f..279d3391f5 100644 --- a/tests/unit_tests/test_cache.py +++ b/tests/unit_tests/test_cache.py @@ -16,12 +16,10 @@ from langchain.chat_models.base import BaseChatModel, dumps from langchain.llms import FakeListLLM from langchain.llms.base import BaseLLM from langchain.schema import ( - AIMessage, - BaseMessage, ChatGeneration, Generation, - HumanMessage, ) +from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage def get_sqlite_cache() -> SQLAlchemyCache: diff --git a/tests/unit_tests/test_schema.py b/tests/unit_tests/test_schema.py index ef9d99187b..facbf93507 100644 --- a/tests/unit_tests/test_schema.py +++ b/tests/unit_tests/test_schema.py @@ -2,7 +2,7 @@ import unittest -from langchain.schema import ( +from langchain.schema.messages import ( AIMessage, HumanMessage, SystemMessage,