Harrison/split schema dir (#7025)

There should be no functional changes.

Also keep __init__ exposing a lot for backwards compatibility.
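For example, both of these should resolve to the same class after the split (a minimal sketch of the compat behavior; the re-exports live in the new langchain/schema/__init__.py shown below):

    # old flat path, still re-exported for backwards compatibility
    from langchain.schema import HumanMessage as LegacyHumanMessage
    # new canonical location
    from langchain.schema.messages import HumanMessage

    assert LegacyHumanMessage is HumanMessage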

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>

@@ -32,10 +32,10 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import (
     AgentAction,
     AgentFinish,
-    BaseMessage,
     BaseOutputParser,
     OutputParserException,
 )
+from langchain.schema.messages import BaseMessage
 from langchain.tools.base import BaseTool
 from langchain.utilities.asyncio import asyncio_timeout

@@ -20,7 +20,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import SystemMessage
+from langchain.schema.messages import SystemMessage
 from langchain.tools.python.tool import PythonAstREPLTool

@@ -10,7 +10,7 @@ from langchain.agents.types import AgentType
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.chains.llm import LLMChain
-from langchain.schema import SystemMessage
+from langchain.schema.messages import SystemMessage
 from langchain.tools.python.tool import PythonREPLTool

@@ -20,7 +20,7 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     MessagesPlaceholder,
 )
-from langchain.schema import AIMessage, SystemMessage
+from langchain.schema.messages import AIMessage, SystemMessage
 def create_sql_agent(

@@ -25,11 +25,9 @@ from langchain.prompts.chat import (
 )
 from langchain.schema import (
     AgentAction,
-    AIMessage,
-    BaseMessage,
     BaseOutputParser,
-    HumanMessage,
 )
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 from langchain.tools.base import BaseTool

@@ -21,10 +21,12 @@ from langchain.prompts.chat import (
 from langchain.schema import (
     AgentAction,
     AgentFinish,
+    OutputParserException,
+)
+from langchain.schema.messages import (
     AIMessage,
     BaseMessage,
     FunctionMessage,
-    OutputParserException,
     SystemMessage,
 )
 from langchain.tools import BaseTool

@@ -21,10 +21,12 @@ from langchain.prompts.chat import (
 from langchain.schema import (
     AgentAction,
     AgentFinish,
+    OutputParserException,
+)
+from langchain.schema.messages import (
     AIMessage,
     BaseMessage,
     FunctionMessage,
-    OutputParserException,
     SystemMessage,
 )
 from langchain.tools import BaseTool

@@ -6,7 +6,8 @@ from typing import Any, List, Optional, Sequence, Set
 from langchain.callbacks.manager import Callbacks
 from langchain.load.serializable import Serializable
-from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
+from langchain.schema import LLMResult, PromptValue
+from langchain.schema.messages import BaseMessage, get_buffer_string
 def _get_token_ids_default_method(text: str) -> List[int]:

@@ -4,7 +4,10 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional, Sequence, Union
 from uuid import UUID
-from langchain.schema import AgentAction, AgentFinish, BaseMessage, Document, LLMResult
+from langchain.schema.agent import AgentAction, AgentFinish
+from langchain.schema.document import Document
+from langchain.schema.messages import BaseMessage
+from langchain.schema.output import LLMResult
 class RetrieverManagerMixin:

@@ -41,11 +41,10 @@ from langchain.callbacks.tracers.wandb import WandbTracer
 from langchain.schema import (
     AgentAction,
     AgentFinish,
-    BaseMessage,
     Document,
     LLMResult,
-    get_buffer_string,
 )
+from langchain.schema.messages import BaseMessage, get_buffer_string
 logger = logging.getLogger(__name__)
 Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

@@ -7,12 +7,14 @@ from uuid import UUID
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import (
+    ChatGeneration,
+    LLMResult,
+)
+from langchain.schema.messages import (
     AIMessage,
     BaseMessage,
-    ChatGeneration,
     ChatMessage,
     HumanMessage,
-    LLMResult,
     SystemMessage,
 )

@@ -17,7 +17,7 @@ from langchain.callbacks.tracers.schemas import (
     TracerSession,
 )
 from langchain.env import get_runtime_environment
-from langchain.schema import BaseMessage, messages_to_dict
+from langchain.schema.messages import BaseMessage, messages_to_dict
 logger = logging.getLogger(__name__)
 _LOGGED = set()

@@ -16,7 +16,7 @@ from langchain.callbacks.tracers.schemas import (
     TracerSessionV1,
     TracerSessionV1Base,
 )
-from langchain.schema import get_buffer_string
+from langchain.schema.messages import get_buffer_string
 from langchain.utils import raise_for_status_with_text

@@ -22,7 +22,8 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseMessage, BaseRetriever, Document
+from langchain.schema import BaseRetriever, Document
+from langchain.schema.messages import BaseMessage
 from langchain.vectorstores.base import VectorStore
 # Depending on the memory type and configuration, the chat history format may differ.

@@ -9,7 +9,7 @@ from langchain.output_parsers.openai_functions import (
     PydanticOutputFunctionsParser,
 )
 from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
-from langchain.schema import HumanMessage, SystemMessage
+from langchain.schema.messages import HumanMessage, SystemMessage
 class FactWithEvidence(BaseModel):

@@ -11,7 +11,8 @@ from langchain.output_parsers.openai_functions import (
 )
 from langchain.prompts import PromptTemplate
 from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
-from langchain.schema import BaseLLMOutputParser, HumanMessage, SystemMessage
+from langchain.schema import BaseLLMOutputParser
+from langchain.schema.messages import HumanMessage, SystemMessage
 class AnswerWithSources(BaseModel):

@@ -7,11 +7,13 @@ from langchain.callbacks.manager import (
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.anthropic import _AnthropicCommon
 from langchain.schema import (
+    ChatGeneration,
+    ChatResult,
+)
+from langchain.schema.messages import (
     AIMessage,
     BaseMessage,
-    ChatGeneration,
     ChatMessage,
-    ChatResult,
     HumanMessage,
     SystemMessage,
 )

@@ -19,15 +19,13 @@ from langchain.callbacks.manager import (
 )
 from langchain.load.dump import dumpd, dumps
 from langchain.schema import (
-    AIMessage,
-    BaseMessage,
     ChatGeneration,
     ChatResult,
-    HumanMessage,
     LLMResult,
     PromptValue,
     RunInfo,
 )
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 def _get_verbosity() -> bool:

@@ -3,7 +3,7 @@ from typing import Any, List, Mapping, Optional
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.chat_models.base import SimpleChatModel
-from langchain.schema import BaseMessage
+from langchain.schema.messages import BaseMessage
 class FakeListChatModel(SimpleChatModel):

@@ -19,11 +19,13 @@ from langchain.callbacks.manager import (
 )
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import (
+    ChatGeneration,
+    ChatResult,
+)
+from langchain.schema.messages import (
     AIMessage,
     BaseMessage,
-    ChatGeneration,
     ChatMessage,
-    ChatResult,
     HumanMessage,
     SystemMessage,
 )

@@ -30,11 +30,13 @@ from langchain.callbacks.manager import (
 )
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import (
+    ChatGeneration,
+    ChatResult,
+)
+from langchain.schema.messages import (
     AIMessage,
     BaseMessage,
-    ChatGeneration,
     ChatMessage,
-    ChatResult,
     FunctionMessage,
     HumanMessage,
     SystemMessage,

@@ -7,7 +7,8 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.chat_models import ChatOpenAI
-from langchain.schema import BaseMessage, ChatResult
+from langchain.schema import ChatResult
+from langchain.schema.messages import BaseMessage
 class PromptLayerChatOpenAI(ChatOpenAI):

@@ -11,10 +11,12 @@ from langchain.callbacks.manager import (
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.vertexai import _VertexAICommon, is_codey_model
 from langchain.schema import (
-    AIMessage,
-    BaseMessage,
     ChatGeneration,
     ChatResult,
+)
+from langchain.schema.messages import (
+    AIMessage,
+    BaseMessage,
     HumanMessage,
     SystemMessage,
 )

@@ -31,10 +31,12 @@ from langchain.chains.base import Chain
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.base import BaseLLM
 from langchain.schema import (
-    BaseMessage,
     ChatResult,
-    HumanMessage,
     LLMResult,
+)
+from langchain.schema.messages import (
+    BaseMessage,
+    HumanMessage,
     get_buffer_string,
     messages_from_dict,
 )

@@ -1,8 +1,6 @@
 """Prompt for trajectory evaluation chain."""
 # flake8: noqa
-from langchain.schema import AIMessage
-from langchain.schema import HumanMessage
-from langchain.schema import SystemMessage
+from langchain.schema.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.prompts.chat import (
     ChatPromptTemplate,

@@ -16,12 +16,10 @@ from langchain.experimental.autonomous_agents.autogpt.prompt_generator import (
 )
 from langchain.memory import ChatMessageHistory
 from langchain.schema import (
-    AIMessage,
     BaseChatMessageHistory,
     Document,
-    HumanMessage,
-    SystemMessage,
 )
+from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
 from langchain.tools.base import BaseTool
 from langchain.tools.human.tool import HumanInputRun
 from langchain.vectorstores.base import VectorStoreRetriever

@@ -7,7 +7,7 @@ from langchain.experimental.autonomous_agents.autogpt.prompt_generator import ge
 from langchain.prompts.chat import (
     BaseChatPromptTemplate,
 )
-from langchain.schema import BaseMessage, HumanMessage, SystemMessage
+from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
 from langchain.tools.base import BaseTool
 from langchain.vectorstores.base import VectorStoreRetriever

@@ -9,7 +9,7 @@ from langchain.experimental.plan_and_execute.schema import (
     Step,
 )
 from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
-from langchain.schema import SystemMessage
+from langchain.schema.messages import SystemMessage
 SYSTEM_PROMPT = (
     "Let's first understand the problem and devise a plan to solve the problem."

@@ -22,14 +22,12 @@ from langchain.callbacks.manager import (
 )
 from langchain.load.dump import dumpd
 from langchain.schema import (
-    AIMessage,
-    BaseMessage,
     Generation,
     LLMResult,
     PromptValue,
     RunInfo,
-    get_buffer_string,
 )
+from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
 def _get_verbosity() -> bool:

@@ -4,7 +4,7 @@ from pydantic import root_validator
 from langchain.memory.chat_memory import BaseChatMemory, BaseMemory
 from langchain.memory.utils import get_prompt_input_key
-from langchain.schema import get_buffer_string
+from langchain.schema.messages import get_buffer_string
 class ConversationBufferMemory(BaseChatMemory):

@@ -1,7 +1,7 @@
 from typing import Any, Dict, List
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import BaseMessage, get_buffer_string
+from langchain.schema.messages import BaseMessage, get_buffer_string
 class ConversationBufferWindowMemory(BaseChatMemory):

@@ -10,10 +10,8 @@ if typing.TYPE_CHECKING:
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    _message_to_dict,
-    messages_from_dict,
 )
+from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
 DEFAULT_TABLE_NAME = "message_store"
 DEFAULT_TTL_SECONDS = None

@@ -7,10 +7,8 @@ from typing import TYPE_CHECKING, Any, List, Optional, Type
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    messages_from_dict,
-    messages_to_dict,
 )
+from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict
 logger = logging.getLogger(__name__)

@@ -3,6 +3,8 @@ from typing import List, Optional
 from langchain.schema import (
     BaseChatMessageHistory,
+)
+from langchain.schema.messages import (
     BaseMessage,
     _message_to_dict,
     messages_from_dict,

@@ -5,10 +5,8 @@ from typing import List
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    messages_from_dict,
-    messages_to_dict,
 )
+from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict
 logger = logging.getLogger(__name__)

@@ -6,10 +6,8 @@ from typing import TYPE_CHECKING, List, Optional
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    messages_from_dict,
-    messages_to_dict,
 )
+from langchain.schema.messages import BaseMessage, messages_from_dict, messages_to_dict
 logger = logging.getLogger(__name__)

@@ -4,8 +4,8 @@ from pydantic import BaseModel
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
 )
+from langchain.schema.messages import BaseMessage
 class ChatMessageHistory(BaseChatMessageHistory, BaseModel):

@@ -6,10 +6,8 @@ from typing import TYPE_CHECKING, Any, Optional
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    _message_to_dict,
-    messages_from_dict,
 )
+from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
 from langchain.utils import get_from_env
 if TYPE_CHECKING:

@@ -4,10 +4,8 @@ from typing import List
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    _message_to_dict,
-    messages_from_dict,
 )
+from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
 logger = logging.getLogger(__name__)

@@ -4,10 +4,8 @@ from typing import List
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    _message_to_dict,
-    messages_from_dict,
 )
+from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
 logger = logging.getLogger(__name__)

@@ -4,10 +4,8 @@ from typing import List, Optional
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    _message_to_dict,
-    messages_from_dict,
 )
+from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
 logger = logging.getLogger(__name__)

@@ -12,10 +12,8 @@ from sqlalchemy.orm import sessionmaker
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    _message_to_dict,
-    messages_from_dict,
 )
+from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
 logger = logging.getLogger(__name__)

@@ -4,11 +4,9 @@ import logging
 from typing import TYPE_CHECKING, Dict, List, Optional
 from langchain.schema import (
-    AIMessage,
     BaseChatMessageHistory,
-    BaseMessage,
-    HumanMessage,
 )
+from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
 if TYPE_CHECKING:
     from zep_python import Memory, MemorySearchResult, Message, NotFoundError

@@ -14,7 +14,7 @@ from langchain.memory.prompt import (
 )
 from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import BaseMessage, get_buffer_string
+from langchain.schema.messages import BaseMessage, get_buffer_string
 logger = logging.getLogger(__name__)

@@ -13,11 +13,7 @@ from langchain.memory.prompt import (
 )
 from langchain.memory.utils import get_prompt_input_key
 from langchain.prompts.base import BasePromptTemplate
-from langchain.schema import (
-    BaseMessage,
-    SystemMessage,
-    get_buffer_string,
-)
+from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
 class ConversationKGMemory(BaseChatMemory):

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import requests
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string
+from langchain.schema.messages import get_buffer_string
 MANAGED_URL = "https://api.getmetal.io/v1/motorhead"
 # LOCAL_URL = "http://localhost:8080"

@@ -11,10 +11,8 @@ from langchain.memory.prompt import SUMMARY_PROMPT
 from langchain.prompts.base import BasePromptTemplate
 from langchain.schema import (
     BaseChatMessageHistory,
-    BaseMessage,
-    SystemMessage,
-    get_buffer_string,
 )
+from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
 class SummarizerMixin(BaseModel):

@@ -4,7 +4,7 @@ from pydantic import root_validator
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.memory.summary import SummarizerMixin
-from langchain.schema import BaseMessage, get_buffer_string
+from langchain.schema.messages import BaseMessage, get_buffer_string
 class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):

@@ -2,7 +2,7 @@ from typing import Any, Dict, List
 from langchain.base_language import BaseLanguageModel
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import BaseMessage, get_buffer_string
+from langchain.schema.messages import BaseMessage, get_buffer_string
 class ConversationTokenBufferMemory(BaseChatMemory):

@@ -1,6 +1,6 @@
 from typing import Any, Dict, List
-from langchain.schema import get_buffer_string  # noqa: 401
+from langchain.schema.messages import get_buffer_string  # noqa: 401
 def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:

@@ -11,7 +11,8 @@ from pydantic import Field, root_validator
 from langchain.formatting import formatter
 from langchain.load.serializable import Serializable
-from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
+from langchain.schema import BaseOutputParser, PromptValue
+from langchain.schema.messages import BaseMessage, HumanMessage
 def jinja2_formatter(template: str, **kwargs: Any) -> str:

@@ -8,16 +8,18 @@ from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
 from pydantic import Field, root_validator
 from langchain.load.serializable import Serializable
-from langchain.memory.buffer import get_buffer_string
 from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import (
+    PromptValue,
+)
+from langchain.schema.messages import (
     AIMessage,
     BaseMessage,
     ChatMessage,
     HumanMessage,
-    PromptValue,
     SystemMessage,
+    get_buffer_string,
 )

@@ -1,886 +0,0 @@
"""Common schema objects."""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from inspect import signature
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
NamedTuple,
Optional,
Sequence,
TypeVar,
Union,
)
from uuid import UUID
from pydantic import BaseModel, Field, root_validator
from langchain.load.serializable import Serializable
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Callbacks,
)
RUN_KEY = "__run"
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain.schema import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
@dataclass
class AgentAction:
"""A full description of an action for an ActionAgent to execute."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action."""
class AgentFinish(NamedTuple):
"""The final return value of an ActionAgent."""
return_values: dict
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value"""
class Generation(Serializable):
"""A single text generation output."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
# TODO: add log probs as separate attribute
@property
def lc_serializable(self) -> bool:
"""Whether this class is LangChain serializable."""
return True
class BaseMessage(Serializable):
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: str
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
@property
@abstractmethod
def type(self) -> str:
"""Type of the Message, used for serialization."""
@property
def lc_serializable(self) -> bool:
"""Whether this class is LangChain serializable."""
return True
class HumanMessage(BaseMessage):
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "human"
class AIMessage(BaseMessage):
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "ai"
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "system"
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "chat"
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of messages as dicts.
"""
return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
else:
raise ValueError(f"Got unexpected type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
class ChatGeneration(Generation):
"""A single chat generation output."""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
values["text"] = values["message"].content
return values
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""
run_id: UUID
"""A unique identifier for the model or chain run."""
class ChatResult(BaseModel):
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class LLMResult(BaseModel):
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
"""Arbitrary LLM provider-specific output."""
run: Optional[List[RunInfo]] = None
"""List of metadata info for model call for each input."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
contains only a single Generation. If token usage information is available,
it is kept only for the LLMResult corresponding to the top-choice
Generation, to avoid over-counting of token usage downstream.
Returns:
List of LLMResults where each returned LLMResult contains a single
Generation.
"""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
if i == 0:
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=self.llm_output,
)
)
else:
if self.llm_output is not None:
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=llm_output,
)
)
return llm_results
def __eq__(self, other: object) -> bool:
"""Check for LLMResult equality by ignoring any metadata related to runs."""
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)
class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
@abstractmethod
def to_string(self) -> str:
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
class BaseMemory(Serializable, ABC):
"""Base abstract class for memory in Chains.
Memory refers to state in Chains. Memory can be used to store information about
past executions of a Chain and inject that information into the inputs of
future executions of the Chain. For example, for conversational Chains Memory
can be used to store conversations and automatically add them to future model
prompts so that the model has the necessary context to respond coherently to
the latest input.
Example:
.. code-block:: python
class SimpleMemory(BaseMemory):
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
pass
def clear(self) -> None:
pass
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this chain run to memory."""
@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
with open(os.path.join(storage_path, session_id), 'r', encoding='utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
messages = messages_to_dict(self.messages)
messages.append(_message_to_dict(message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(messages, f)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
def add_user_message(self, message: str) -> None:
"""Convenience method for adding a human message string to the store.
Args:
message: The string contents of a human message.
"""
self.add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Convenience method for adding an AI message string to the store.
Args:
message: The string contents of an AI message.
"""
self.add_message(AIMessage(content=message))
# TODO: Make this an abstractmethod.
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""
class Document(Serializable):
"""Class for storing a piece of text and associated metadata."""
page_content: str
"""String text."""
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
class BaseRetriever(ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
the most 'relevant' Documents from some source.
Example:
.. code-block:: python
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
tfidf_array: Any
k: int = 4
class Config:
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
query_vec = self.vectorizer.transform([query])
# Op -- (n_docs,1) -- Cosine Sim with each doc
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
""" # noqa: E501
_new_arg_supported: bool = False
_expects_other_args: bool = False
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Version upgrade for old retrievers that implemented the public
# methods directly.
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
warnings.warn(
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.get_relevant_documents
)
cls._get_relevant_documents = swap # type: ignore[assignment]
if (
hasattr(cls, "aget_relevant_documents")
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
):
warnings.warn(
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.aget_relevant_documents
)
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for.
run_manager: The callbacks handler to use.
Returns:
List of relevant documents
"""
@abstractmethod
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
def get_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: String to find relevant documents for.
callbacks: Callback manager or list of callbacks.
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = callback_manager.on_retriever_start(
query,
**kwargs,
)
try:
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **kwargs
)
elif self._expects_other_args:
result = self._get_relevant_documents(query, **kwargs)
else:
result = self._get_relevant_documents(query) # type: ignore[call-arg]
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
run_manager.on_retriever_end(
result,
**kwargs,
)
return result
async def aget_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = await callback_manager.on_retriever_start(
query,
**kwargs,
)
try:
if self._new_arg_supported:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **kwargs
)
elif self._expects_other_args:
result = await self._aget_relevant_documents(query, **kwargs)
else:
result = await self._aget_relevant_documents(
query, # type: ignore[call-arg]
)
except Exception as e:
await run_manager.on_retriever_error(e)
raise e
else:
await run_manager.on_retriever_end(
result,
**kwargs,
)
return result
# For backwards compatibility
Memory = BaseMemory
T = TypeVar("T")
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
Example:
.. code-block:: python
class BooleanOutputParser(BaseOutputParser[bool]):
true_val: str = "YES"
false_val: str = "NO"
def parse(self, text: str) -> bool:
cleaned_text = text.strip().upper()
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
raise OutputParserException(
f"BooleanOutputParser expected output value to either be "
f"{self.true_val} or {self.false_val} (case-insensitive). "
f"Received {cleaned_text}."
)
return cleaned_text == self.true_val.upper()
@property
def _type(self) -> str:
return "boolean_output_parser"
""" # noqa: E501
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return self.parse(result[0].text)
@abstractmethod
def parse(self, text: str) -> T:
"""Parse a single string model output into some structure.
Args:
text: String output of language model.
Returns:
Structured output.
"""
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Parse the output of an LLM call with the input prompt for context.
The prompt is largely provided in the event the OutputParser wants
to retry or fix the output in some way, and needs information from
the prompt to do so.
Args:
completion: String output of language model.
prompt: Input PromptValue.
Returns:
Structured output
"""
return self.parse(completion)
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
output_parser_dict["_type"] = self._type
return output_parser_dict
class NoOpOutputParser(BaseOutputParser[str]):
"""'No operation' OutputParser that returns the text as is."""
@property
def lc_serializable(self) -> bool:
"""Whether the class LangChain serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
class OutputParserException(ValueError):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
Args:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
llm_output: String model output which is error-ing.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
"""
def __init__(
self,
error: Any,
observation: Optional[str] = None,
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
"Arguments 'observation' & 'llm_output'"
" are required if 'send_to_llm' is True"
)
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.
A document transformation system takes a sequence of Documents and returns a
sequence of transformed Documents.
Example:
.. code-block:: python
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
embeddings: Embeddings
similarity_fn: Callable = cosine_similarity
similarity_threshold: float = 0.95
class Config:
arbitrary_types_allowed = True
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
stateful_documents = get_stateful_documents(documents)
embedded_documents = _get_embeddings_from_stateful_docs(
self.embeddings, stateful_documents
)
included_idxs = _filter_similar_embeddings(
embedded_documents, self.similarity_fn, self.similarity_threshold
)
return [stateful_documents[i] for i in sorted(included_idxs)]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
""" # noqa: E501
@abstractmethod
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
@abstractmethod
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""

@@ -0,0 +1,67 @@
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
from langchain.schema.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
_message_from_dict,
_message_to_dict,
get_buffer_string,
messages_from_dict,
messages_to_dict,
)
from langchain.schema.output import (
ChatGeneration,
ChatResult,
Generation,
LLMResult,
RunInfo,
)
from langchain.schema.output_parser import (
BaseLLMOutputParser,
BaseOutputParser,
NoOpOutputParser,
OutputParserException,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.retriever import BaseRetriever
RUN_KEY = "__run"
Memory = BaseMemory
__all__ = [
"BaseMemory",
"BaseChatMessageHistory",
"AgentFinish",
"AgentAction",
"Document",
"BaseDocumentTransformer",
"BaseMessage",
"ChatMessage",
"FunctionMessage",
"HumanMessage",
"AIMessage",
"SystemMessage",
"messages_from_dict",
"messages_to_dict",
"_message_to_dict",
"_message_from_dict",
"get_buffer_string",
"RunInfo",
"LLMResult",
"ChatResult",
"ChatGeneration",
"Generation",
"PromptValue",
"BaseRetriever",
"RUN_KEY",
"Memory",
"OutputParserException",
"NoOpOutputParser",
"BaseOutputParser",
"BaseLLMOutputParser",
]
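Since __all__ drives wildcard imports, existing "from langchain.schema import *" call sites and the legacy Memory alias keep working too (an illustrative check against the re-exports above):

    from langchain.schema import *  # noqa: F401,F403

    assert Memory is BaseMemory  # legacy alias preserved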

@@ -0,0 +1,25 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import NamedTuple, Union
@dataclass
class AgentAction:
"""A full description of an action for an ActionAgent to execute."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action."""
class AgentFinish(NamedTuple):
"""The final return value of an ActionAgent."""
return_values: dict
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value"""

@@ -0,0 +1,82 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Sequence
from pydantic import Field
from langchain.load.serializable import Serializable
class Document(Serializable):
"""Class for storing a piece of text and associated metadata."""
page_content: str
"""String text."""
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.
A document transformation system takes a sequence of Documents and returns a
sequence of transformed Documents.
Example:
.. code-block:: python
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
embeddings: Embeddings
similarity_fn: Callable = cosine_similarity
similarity_threshold: float = 0.95
class Config:
arbitrary_types_allowed = True
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
stateful_documents = get_stateful_documents(documents)
embedded_documents = _get_embeddings_from_stateful_docs(
self.embeddings, stateful_documents
)
included_idxs = _filter_similar_embeddings(
embedded_documents, self.similarity_fn, self.similarity_threshold
)
return [stateful_documents[i] for i in sorted(included_idxs)]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
""" # noqa: E501
@abstractmethod
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
@abstractmethod
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""

@@ -0,0 +1,121 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from langchain.load.serializable import Serializable
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
class BaseMemory(Serializable, ABC):
"""Base abstract class for memory in Chains.
Memory refers to state in Chains. Memory can be used to store information about
past executions of a Chain and inject that information into the inputs of
future executions of the Chain. For example, for conversational Chains Memory
can be used to store conversations and automatically add them to future model
prompts so that the model has the necessary context to respond coherently to
the latest input.
Example:
.. code-block:: python
class SimpleMemory(BaseMemory):
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
pass
def clear(self) -> None:
pass
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this chain run to memory."""
@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
Example:
.. code-block:: python
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
with open(os.path.join(storage_path, session_id), 'r', encoding='utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
def add_message(self, message: BaseMessage) -> None:
messages = messages_to_dict(self.messages)
messages.append(_message_to_dict(message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(messages, f)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
def add_user_message(self, message: str) -> None:
"""Convenience method for adding a human message string to the store.
Args:
message: The string contents of a human message.
"""
self.add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Convenience method for adding an AI message string to the store.
Args:
message: The string contents of an AI message.
"""
self.add_message(AIMessage(content=message))
# TODO: Make this an abstractmethod.
def add_message(self, message: BaseMessage) -> None:
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""

@@ -0,0 +1,183 @@
from __future__ import annotations
from abc import abstractmethod
from typing import List, Sequence
from pydantic import Field
from langchain.load.serializable import Serializable
def get_buffer_string(
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain.schema import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
role = human_prefix
elif isinstance(m, AIMessage):
role = ai_prefix
elif isinstance(m, SystemMessage):
role = "System"
elif isinstance(m, FunctionMessage):
role = "Function"
elif isinstance(m, ChatMessage):
role = m.role
else:
raise ValueError(f"Got unsupported message type: {m}")
message = f"{role}: {m.content}"
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
message += f"{m.additional_kwargs['function_call']}"
string_messages.append(message)
return "\n".join(string_messages)
class BaseMessage(Serializable):
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: str
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
@property
@abstractmethod
def type(self) -> str:
"""Type of the Message, used for serialization."""
@property
def lc_serializable(self) -> bool:
"""Whether this class is LangChain serializable."""
return True
class HumanMessage(BaseMessage):
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "human"
class AIMessage(BaseMessage):
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "ai"
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "system"
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "chat"
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of messages as dicts.
"""
return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
else:
raise ValueError(f"Got unexpected type: {_type}")
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
"""
return [_message_from_dict(m) for m in messages]
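The dict helpers above give a lossless round trip for the message types they handle (an illustrative check; note that _message_from_dict has no branch for the "function" type, so a FunctionMessage would not survive this round trip):

    from langchain.schema.messages import (
        AIMessage,
        HumanMessage,
        messages_from_dict,
        messages_to_dict,
    )

    history = [HumanMessage(content="hi"), AIMessage(content="hello")]
    assert messages_from_dict(messages_to_dict(history)) == history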

@@ -0,0 +1,118 @@
from __future__ import annotations
from copy import deepcopy
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel, root_validator
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
class Generation(Serializable):
"""A single text generation output."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
# TODO: add log probs as separate attribute
@property
def lc_serializable(self) -> bool:
"""Whether this class is LangChain serializable."""
return True
class ChatGeneration(Generation):
"""A single chat generation output."""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
values["text"] = values["message"].content
return values
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""
run_id: UUID
"""A unique identifier for the model or chain run."""
class ChatResult(BaseModel):
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class LLMResult(BaseModel):
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
"""Arbitrary LLM provider-specific output."""
run: Optional[List[RunInfo]] = None
"""List of metadata info for model call for each input."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
contains only a single Generation. If token usage information is available,
it is kept only for the LLMResult corresponding to the top-choice
Generation, to avoid over-counting of token usage downstream.
Returns:
List of LLMResults where each returned LLMResult contains a single
Generation.
"""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
if i == 0:
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=self.llm_output,
)
)
else:
if self.llm_output is not None:
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=llm_output,
)
)
return llm_results
def __eq__(self, other: object) -> bool:
"""Check for LLMResult equality by ignoring any metadata related to runs."""
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)
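LLMResult.flatten keeps token usage only on the first flattened result, as its docstring describes (a small illustrative check):

    from langchain.schema.output import Generation, LLMResult

    result = LLMResult(
        generations=[[Generation(text="a")], [Generation(text="b")]],
        llm_output={"token_usage": {"total_tokens": 7}},
    )
    first, second = result.flatten()
    assert first.llm_output == {"token_usage": {"total_tokens": 7}}
    assert second.llm_output == {"token_usage": {}}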

@@ -0,0 +1,172 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar
from langchain.load.serializable import Serializable
from langchain.schema.output import Generation
from langchain.schema.prompt import PromptValue
T = TypeVar("T")
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
Example:
.. code-block:: python
class BooleanOutputParser(BaseOutputParser[bool]):
true_val: str = "YES"
false_val: str = "NO"
def parse(self, text: str) -> bool:
cleaned_text = text.strip().upper()
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
raise OutputParserException(
f"BooleanOutputParser expected output value to either be "
f"{self.true_val} or {self.false_val} (case-insensitive). "
f"Received {cleaned_text}."
)
return cleaned_text == self.true_val.upper()
@property
def _type(self) -> str:
return "boolean_output_parser"
""" # noqa: E501
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return self.parse(result[0].text)
@abstractmethod
def parse(self, text: str) -> T:
"""Parse a single string model output into some structure.
Args:
text: String output of language model.
Returns:
Structured output.
"""
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Parse the output of an LLM call with the input prompt for context.
The prompt is largely provided in the event the OutputParser wants
to retry or fix the output in some way, and needs information from
the prompt to do so.
Args:
completion: String output of language model.
prompt: Input PromptValue.
Returns:
Structured output
"""
return self.parse(completion)
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict(**kwargs)
output_parser_dict["_type"] = self._type
return output_parser_dict
class NoOpOutputParser(BaseOutputParser[str]):
"""'No operation' OutputParser that returns the text as is."""
@property
def lc_serializable(self) -> bool:
"""Whether the class LangChain serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
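A usage sketch (not part of the diff): NoOpOutputParser passes text through, while a subclass such as the BooleanOutputParser shown in the BaseOutputParser docstring above inherits parse_result, which parses only the first (top-choice) Generation.

from langchain.schema.output import Generation

parser = NoOpOutputParser()
assert parser.parse("42") == "42"

# Assuming BooleanOutputParser is defined as in the docstring example above:
bool_parser = BooleanOutputParser()
assert bool_parser.parse_result([Generation(text="YES")]) is True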
class OutputParserException(ValueError):
"""Exception that output parsers should raise to signify a parsing error.
This exists to differentiate parsing errors from other code or execution errors
that may also arise inside the output parser. An OutputParserException can be
caught and handled in ways that attempt to fix the parsing error, while other
errors propagate as usual.
Args:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
llm_output: String model output that triggered the error.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
"""
def __init__(
self,
error: Any,
observation: Optional[str] = None,
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
if send_to_llm:
if observation is None or llm_output is None:
raise ValueError(
"Arguments 'observation' & 'llm_output'"
" are required if 'send_to_llm' is True"
)
self.observation = observation
self.llm_output = llm_output
self.send_to_llm = send_to_llm
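A sketch of the intended control flow (illustrative values only): raise with send_to_llm=True so an agent can hand the observation back to the model.

try:
    raise OutputParserException(
        "Expected YES or NO",
        observation="Answer with exactly YES or NO.",
        llm_output="Maybe",
        send_to_llm=True,
    )
except OutputParserException as e:
    # An agent can feed e.observation and e.llm_output back to the
    # model so it can correct the output format on a retry.
    assert e.send_to_llm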

@ -0,0 +1,23 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
@abstractmethod
def to_string(self) -> str:
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""

@ -0,0 +1,191 @@
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, List
from langchain.schema.document import Document
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Callbacks,
)
class BaseRetriever(ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
the most 'relevant' Documents from some source.
Example:
.. code-block:: python
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
tfidf_array: Any
k: int = 4
class Config:
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
# Vectorize the query into the same (1, n_features) tf-idf space as the docs
query_vec = self.vectorizer.transform([query])
# (n_docs,) array of cosine similarities between each doc and the query
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
""" # noqa: E501
_new_arg_supported: bool = False
_expects_other_args: bool = False
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Version upgrade for old retrievers that implemented the public
# methods directly.
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
warnings.warn(
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.get_relevant_documents
)
cls._get_relevant_documents = swap # type: ignore[assignment]
if (
hasattr(cls, "aget_relevant_documents")
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
):
warnings.warn(
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.aget_relevant_documents
)
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: String to find relevant documents for.
run_manager: The callbacks handler to use.
Returns:
List of relevant documents
"""
@abstractmethod
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for.
run_manager: The callbacks handler to use.
Returns:
List of relevant documents
"""
def get_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: String to find relevant documents for.
callbacks: Callback manager or list of callbacks.
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = callback_manager.on_retriever_start(
query,
**kwargs,
)
try:
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **kwargs
)
elif self._expects_other_args:
result = self._get_relevant_documents(query, **kwargs)
else:
result = self._get_relevant_documents(query) # type: ignore[call-arg]
except Exception as e:
run_manager.on_retriever_error(e)
raise
else:
run_manager.on_retriever_end(
result,
**kwargs,
)
return result
async def aget_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: String to find relevant documents for.
callbacks: Callback manager or list of callbacks.
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = await callback_manager.on_retriever_start(
query,
**kwargs,
)
try:
if self._new_arg_supported:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **kwargs
)
elif self._expects_other_args:
result = await self._aget_relevant_documents(query, **kwargs)
else:
result = await self._aget_relevant_documents(
query, # type: ignore[call-arg]
)
except Exception as e:
await run_manager.on_retriever_error(e)
raise
else:
await run_manager.on_retriever_end(
result,
**kwargs,
)
return result
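A sketch of the new subclassing contract (this KeywordRetriever is hypothetical, and the langchain.schema.retriever module path assumes the layout this PR creates): implement the private _get_relevant_documents / _aget_relevant_documents hooks with the run_manager keyword, then call the public methods, which wire up callbacks.

from typing import Any, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema.document import Document
from langchain.schema.retriever import BaseRetriever

class KeywordRetriever(BaseRetriever):
    """Hypothetical retriever returning docs that contain the query string."""

    def __init__(self, docs: List[Document]) -> None:
        self.docs = docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        # No real async work in this sketch; reuse the sync filter.
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

retriever = KeywordRetriever(docs=[Document(page_content="hello world")])
print(retriever.get_relevant_documents("hello"))  # -> [Document(...)]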

@ -6,12 +6,10 @@ import pytest
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

@ -8,13 +8,11 @@ import pytest
from langchain.chat_models import ChatGooglePalm
from langchain.schema import (
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
SystemMessage,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
def test_chat_google_palm() -> None:

@ -6,13 +6,11 @@ import pytest
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema import (
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
SystemMessage,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

@ -5,13 +5,11 @@ import pytest
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.schema import (
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
SystemMessage,
)
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

@ -13,11 +13,7 @@ import pytest
from langchain.chat_models import ChatVertexAI
from langchain.chat_models.vertexai import _MessagePair, _parse_chat_history
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
def test_vertexai_single_call() -> None:

@ -8,10 +8,7 @@ from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
from langchain.schema import (
AIMessage,
HumanMessage,
)
from langchain.schema.messages import AIMessage, HumanMessage
def _chat_message_history(

@ -3,7 +3,7 @@ import os
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory
from langchain.schema import _message_to_dict
from langchain.schema.messages import _message_to_dict
# Replace these with your Azure Cosmos DB endpoint and key
endpoint = os.environ["COSMOS_DB_ENDPOINT"]

@ -2,7 +2,7 @@ import json
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import FirestoreChatMessageHistory
from langchain.schema import _message_to_dict
from langchain.schema.messages import _message_to_dict
def test_memory_with_message_store() -> None:

@ -14,7 +14,7 @@ from momento import CacheClient, Configurations, CredentialProvider
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import MomentoChatMessageHistory
from langchain.schema import _message_to_dict
from langchain.schema.messages import _message_to_dict
def random_string() -> str:

@ -3,7 +3,7 @@ import os
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import MongoDBChatMessageHistory
from langchain.schema import _message_to_dict
from langchain.schema.messages import _message_to_dict
# Replace these with your mongodb connection string
connection_string = os.environ["MONGODB_CONNECTION_STRING"]

@ -2,7 +2,7 @@ import json
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain.schema import _message_to_dict
from langchain.schema.messages import _message_to_dict
def test_memory_with_message_store() -> None:

@ -6,7 +6,7 @@ from uuid import UUID
from pydantic import BaseModel
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessage
class BaseFakeCallbackHandler(BaseModel):

@ -11,7 +11,8 @@ from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.schemas import Run
from langchain.schema import HumanMessage, LLMResult
from langchain.schema import LLMResult
from langchain.schema.messages import HumanMessage
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}

@ -18,7 +18,8 @@ from langchain.callbacks.tracers.langchain_v1 import (
TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
from langchain.schema import HumanMessage, LLMResult
from langchain.schema import LLMResult
from langchain.schema.messages import HumanMessage
TEST_SESSION_ID = 2023

@ -7,11 +7,7 @@ from langchain.chat_models.google_palm import (
ChatGooglePalmError,
_messages_to_prompt_dict,
)
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
def test_messages_to_prompt_dict_with_valid_messages() -> None:

@ -5,9 +5,7 @@ import json
from langchain.chat_models.openai import (
_convert_dict_to_message,
)
from langchain.schema import (
FunctionMessage,
)
from langchain.schema.messages import FunctionMessage
def test_function_message_dict_to_function_message() -> None:

@ -6,7 +6,8 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import AIMessage, BaseMessage
class FakeChatModel(SimpleChatModel):

@ -1,7 +1,7 @@
"""Test LLM callbacks."""
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.schema import HumanMessage
from langchain.schema.messages import HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import (
FakeCallbackHandler,
FakeCallbackHandlerWithChatStart,

@ -5,7 +5,7 @@ from typing import Generator
import pytest
from langchain.memory.chat_message_histories import FileChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
from langchain.schema.messages import AIMessage, HumanMessage
@pytest.fixture

@ -4,7 +4,7 @@ from typing import Tuple
import pytest
from langchain.memory.chat_message_histories import SQLChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
from langchain.schema.messages import AIMessage, HumanMessage
# @pytest.fixture(params=[("SQLite"), ("postgresql")])

@ -4,7 +4,7 @@ import pytest
from pytest_mock import MockerFixture
from langchain.memory.chat_message_histories import ZepChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
from langchain.schema.messages import AIMessage, HumanMessage
if TYPE_CHECKING:
from zep_python import ZepClient

@ -13,7 +13,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import HumanMessage
from langchain.schema.messages import HumanMessage
def create_messages() -> List[BaseMessagePromptTemplate]:

@ -16,12 +16,10 @@ from langchain.chat_models.base import BaseChatModel, dumps
from langchain.llms import FakeListLLM
from langchain.llms.base import BaseLLM
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
Generation,
HumanMessage,
)
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
def get_sqlite_cache() -> SQLAlchemyCache:

@ -2,7 +2,7 @@
import unittest
from langchain.schema import (
from langchain.schema.messages import (
AIMessage,
HumanMessage,
SystemMessage,
