Improve docstrings for langchain.schema.py (#6802)

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/7025/head
Davis Chase 1 year ago committed by GitHub
parent 0498dad562
commit 556c425042

@@ -3,6 +3,7 @@ from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from inspect import signature
from typing import (
@@ -34,9 +35,30 @@ RUN_KEY = "__run"
def get_buffer_string(
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
"""Get buffer string of messages."""
"""Convert sequence of Messages to strings and concatenate them into one string.
Args:
messages: Messages to be converted to strings.
human_prefix: The prefix to prepend to contents of HumanMessages.
ai_prefix: The prefix to prepend to contents of AIMessages.
Returns:
A single string concatenation of all input messages.
Example:
.. code-block:: python
from langchain.schema import AIMessage, HumanMessage
messages = [
HumanMessage(content="Hi, how are you?"),
AIMessage(content="Good, how are you?"),
]
get_buffer_string(messages)
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
"""
string_messages = []
for m in messages:
if isinstance(m, HumanMessage):
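For reference, the folded remainder of this loop maps each message type to a prefix; a quick usage sketch exercising the non-default prefixes (the expected output is an assumption based on that mapping):

.. code-block:: python

    from langchain.schema import ChatMessage, SystemMessage, get_buffer_string

    messages = [
        SystemMessage(content="You are a terse assistant."),
        ChatMessage(role="Reviewer", content="Looks good to me."),
    ]
    # SystemMessages are assumed to render with a fixed "System" prefix;
    # ChatMessages use their own role attribute as the prefix.
    get_buffer_string(messages)
    # -> "System: You are a terse assistant.\nReviewer: Looks good to me."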
@@ -61,58 +83,73 @@ def get_buffer_string(
@dataclass
class AgentAction:
"""Agent's action to take."""
"""A full description of an action for an ActionAgent to execute."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action."""
class AgentFinish(NamedTuple):
"""Agent's return value."""
"""The final return value of an ActionAgent."""
return_values: dict
"""Dictionary of return values."""
log: str
"""Additional information to log about the return value"""
class Generation(Serializable):
"""Output of a single generation."""
"""A single text generation output."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
"""Raw response from the provider. May include things like the
reason for finishing or token log probabilities.
"""
# TODO: add log probs as separate attribute
@property
def lc_serializable(self) -> bool:
"""This class is LangChain serializable."""
"""Whether this class is LangChain serializable."""
return True
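For illustration, constructing a Generation with provider metadata (generation_info keys are provider-specific; "finish_reason" is just an example):

.. code-block:: python

    from langchain.schema import Generation

    gen = Generation(
        text="Paris is the capital of France.",
        # Provider-specific, e.g. OpenAI reports why decoding stopped.
        generation_info={"finish_reason": "stop"},
    )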
class BaseMessage(Serializable):
"""Message object."""
"""The base abstract Message class.
Messages are the inputs and outputs of ChatModels.
"""
content: str
"""The string contents of the message."""
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
@property
@abstractmethod
def type(self) -> str:
"""Type of the message, used for serialization."""
"""Type of the Message, used for serialization."""
@property
def lc_serializable(self) -> bool:
"""This class is LangChain serializable."""
"""Whether this class is LangChain serializable."""
return True
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
"""A Message from a human."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
@@ -121,9 +158,12 @@ class HumanMessage(BaseMessage):
class AIMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
"""A Message from an AI."""
example: bool = False
"""Whether this Message is being passed in to the model as part of an example
conversation.
"""
@property
def type(self) -> str:
@@ -132,7 +172,9 @@ class AIMessage(BaseMessage):
class SystemMessage(BaseMessage):
"""Type of message that is a system message."""
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
"""
@property
def type(self) -> str:
@@ -141,7 +183,10 @@ class SystemMessage(BaseMessage):
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
@property
def type(self) -> str:
@@ -150,9 +195,10 @@ class FunctionMessage(BaseMessage):
class ChatMessage(BaseMessage):
"""Type of message with arbitrary speaker."""
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
role: str
"""The speaker / role of the Message."""
@property
def type(self) -> str:
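A quick tour of the concrete Message types defined above (contents are illustrative):

.. code-block:: python

    from langchain.schema import (
        AIMessage,
        ChatMessage,
        FunctionMessage,
        HumanMessage,
        SystemMessage,
    )

    system = SystemMessage(content="You are a helpful assistant.")
    human = HumanMessage(content="What is 2 + 2?")
    ai = AIMessage(content="4")
    # Carries a function's result back to the model.
    func = FunctionMessage(name="calculator", content="4")
    # Allows an arbitrary speaker/role string.
    chat = ChatMessage(role="moderator", content="Keep answers short.")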
@@ -164,14 +210,14 @@ def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
"""Convert messages to dict.
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
"""Convert a sequence of Messages to a list of dictionaries.
Args:
messages: List of messages to convert.
messages: Sequence of messages (as BaseMessages) to convert.
Returns:
List of dicts.
List of messages as dicts.
"""
return [_message_to_dict(m) for m in messages]
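Together with messages_from_dict below, this gives a simple serialization round trip (a sketch; the exact dict layout follows _message_to_dict):

.. code-block:: python

    from langchain.schema import (
        AIMessage,
        HumanMessage,
        messages_from_dict,
        messages_to_dict,
    )

    original = [HumanMessage(content="Hi"), AIMessage(content="Hello!")]
    as_dicts = messages_to_dict(original)  # e.g. [{"type": "human", "data": {...}}, ...]
    restored = messages_from_dict(as_dicts)
    assert [m.content for m in restored] == ["Hi", "Hello!"]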
@@ -191,10 +237,10 @@ def _message_from_dict(message: dict) -> BaseMessage:
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
"""Convert messages from dict.
"""Convert a sequence of messages from dicts to Message objects.
Args:
messages: List of messages (dicts) to convert.
messages: Sequence of messages (as dicts) to convert.
Returns:
List of messages (BaseMessages).
@@ -203,45 +249,61 @@ def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
class ChatGeneration(Generation):
"""Output of a single generation."""
"""A single chat generation output."""
text = ""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
message: BaseMessage
"""The message output by the chat model."""
@root_validator
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set the text attribute to be the contents of the message."""
values["text"] = values["message"].content
return values
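The validator above means callers provide only the message; text is derived from it. A small sanity check (sketch):

.. code-block:: python

    from langchain.schema import AIMessage, ChatGeneration

    gen = ChatGeneration(message=AIMessage(content="Hello!"))
    # set_text copied the message content into the inherited text field.
    assert gen.text == "Hello!"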
class RunInfo(BaseModel):
"""Class that contains all relevant metadata for a Run."""
"""Class that contains metadata for a single execution of a Chain or model."""
run_id: UUID
"""A unique identifier for the model or chain run."""
class ChatResult(BaseModel):
"""Class that contains all relevant information for a Chat Result."""
"""Class that contains all results for a single chat model call."""
generations: List[ChatGeneration]
"""List of the things generated."""
"""List of the chat generations. This is a List because an input can have multiple
candidate generations.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
class LLMResult(BaseModel):
"""Class that contains all relevant information for an LLM Result."""
"""Class that contains all results for a batched LLM call."""
generations: List[List[Generation]]
"""List of the things generated. This is List[List[]] because
each input could have multiple generations."""
"""List of generated outputs. This is a List[List[]] because
each input could have multiple candidate generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
"""Arbitrary LLM provider-specific output."""
run: Optional[List[RunInfo]] = None
"""Run metadata."""
"""List of metadata info for model call for each input."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list."""
"""Flatten generations into a single list.
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
contains only a single Generation. If token usage information is available,
it is kept only for the LLMResult corresponding to the top-choice
Generation, to avoid over-counting of token usage downstream.
Returns:
List of LLMResults where each returned LLMResult contains a single
Generation.
"""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
@@ -254,7 +316,7 @@ class LLMResult(BaseModel):
)
else:
if self.llm_output is not None:
llm_output = self.llm_output.copy()
llm_output = deepcopy(self.llm_output)
llm_output["token_usage"] = dict()
else:
llm_output = None
@@ -267,6 +329,7 @@ class LLMResult(BaseModel):
return llm_results
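A sketch of flatten on a batched result, assuming the token-usage handling described in the docstring and the deepcopy branch above:

.. code-block:: python

    from langchain.schema import Generation, LLMResult

    batched = LLMResult(
        generations=[
            [Generation(text="answer for input 1")],
            [Generation(text="answer for input 2")],
        ],
        llm_output={"token_usage": {"total_tokens": 42}},
    )

    flat = batched.flatten()
    assert len(flat) == 2
    # Token usage is kept only on the first flattened result so that
    # downstream accounting does not double count.
    assert flat[0].llm_output == {"token_usage": {"total_tokens": 42}}
    assert flat[1].llm_output == {"token_usage": {}}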
def __eq__(self, other: object) -> bool:
"""Check for LLMResult equality by ignoring any metadata related to runs."""
if not isinstance(other, LLMResult):
return NotImplemented
return (
@@ -276,17 +339,50 @@ class LLMResult(BaseModel):
class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
@abstractmethod
def to_string(self) -> str:
"""Return prompt as string."""
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
"""Return prompt as a list of Messages."""
class BaseMemory(Serializable, ABC):
"""Base interface for memory in chains."""
"""Base abstract class for memory in Chains.
Memory refers to state in Chains. Memory can be used to store information about
past executions of a Chain and inject that information into the inputs of
future executions of the Chain. For example, for conversational Chains Memory
can be used to store conversations and automatically add them to future model
prompts so that the model has the necessary context to respond coherently to
the latest input.
Example:
.. code-block:: python
class SimpleMemory(BaseMemory):
memories: Dict[str, Any] = dict()
@property
def memory_variables(self) -> List[str]:
return list(self.memories.keys())
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.memories
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
pass
def clear(self) -> None:
pass
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
@@ -296,18 +392,15 @@ class BaseMemory(Serializable, ABC):
@property
@abstractmethod
def memory_variables(self) -> List[str]:
"""Input keys this memory class will load dynamically."""
"""The string keys this memory class will add to chain inputs."""
@abstractmethod
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain.
If None, return all memories
"""
"""Return key-value pairs given the text input to the chain."""
@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this model run to memory."""
"""Save the context of this chain run to memory."""
@abstractmethod
def clear(self) -> None:
@@ -315,11 +408,10 @@ class BaseChatMessageHistory(ABC):
class BaseChatMessageHistory(ABC):
"""Base interface for chat message history
"""Abstract base class for storing chat message history.
See `ChatMessageHistory` for default implementation.
"""
"""
Example:
.. code-block:: python
@@ -337,24 +429,38 @@ class BaseChatMessageHistory(ABC):
messages = messages_to_dict(self.messages)
messages.append(_message_to_dict(message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(messages, f)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
"""Convenience method for adding a human message string to the store.
Args:
message: The string contents of a human message.
"""
self.add_message(HumanMessage(content=message))
def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store"""
"""Convenience method for adding an AI message string to the store.
Args:
message: The string contents of an AI message.
"""
self.add_message(AIMessage(content=message))
# TODO: Make this an abstractmethod.
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
"""Add a Message object to the store.
Args:
message: A BaseMessage object to store.
"""
raise NotImplementedError
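Given any concrete subclass, the two convenience methods just wrap strings in Message objects; the in-memory implementation below is hypothetical:

.. code-block:: python

    from typing import List

    from langchain.schema import (
        AIMessage,
        BaseChatMessageHistory,
        BaseMessage,
        HumanMessage,
    )

    class InMemoryHistory(BaseChatMessageHistory):
        def __init__(self) -> None:
            self.messages: List[BaseMessage] = []

        def add_message(self, message: BaseMessage) -> None:
            self.messages.append(message)

        def clear(self) -> None:
            self.messages = []

    history = InMemoryHistory()
    history.add_user_message("Hi!")          # stored as a HumanMessage
    history.add_ai_message("Hello, human.")  # stored as an AIMessage
    assert isinstance(history.messages[0], HumanMessage)
    assert isinstance(history.messages[1], AIMessage)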
@abstractmethod
@@ -363,14 +469,47 @@ class BaseChatMessageHistory(ABC):
class Document(Serializable):
"""Interface for interacting with a document."""
"""Class for storing a piece of text and associated metadata."""
page_content: str
"""String text."""
metadata: dict = Field(default_factory=dict)
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
class BaseRetriever(ABC):
"""Base interface for a retriever."""
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
the most 'relevant' Documents from some source.
Example:
.. code-block:: python
class TFIDFRetriever(BaseRetriever, BaseModel):
vectorizer: Any
docs: List[Document]
tfidf_array: Any
k: int = 4
class Config:
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
query_vec = self.vectorizer.transform([query])
# Op -- (n_docs,1) -- Cosine Sim with each doc
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
""" # noqa: E501
_new_arg_supported: bool = False
_expects_other_args: bool = False
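Wiring up and querying the TFIDFRetriever sketch from the docstring might look like this (the fit/transform details are assumptions):

.. code-block:: python

    from sklearn.feature_extraction.text import TfidfVectorizer

    from langchain.schema import Document

    texts = ["dogs are loyal", "cats are independent", "parrots can talk"]
    vectorizer = TfidfVectorizer()
    retriever = TFIDFRetriever(
        vectorizer=vectorizer,
        docs=[Document(page_content=t) for t in texts],
        tfidf_array=vectorizer.fit_transform(texts),
        k=2,
    )
    retriever.get_relevant_documents("talking birds")
    # -> the two Documents most cosine-similar to the query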
@@ -415,8 +554,8 @@ class BaseRetriever(ABC):
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: string to find relevant documents for
run_manager: The callbacks handler to use
query: String to find relevant documents for.
run_manager: The callbacks handler to use.
Returns:
List of relevant documents
"""
@@ -442,8 +581,8 @@ class BaseRetriever(ABC):
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
query: String to find relevant documents for.
callbacks: Callback manager or list of callbacks.
Returns:
List of relevant documents
"""
@@ -517,55 +656,94 @@ class BaseRetriever(ABC):
# For backwards compatibility
Memory = BaseMemory
T = TypeVar("T")
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation]) -> T:
"""Parse LLM Result."""
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
"""
Example:
.. code-block:: python
class BooleanOutputParser(BaseOutputParser[bool]):
true_val: str = "YES"
false_val: str = "NO"
def parse(self, text: str) -> bool:
cleaned_text = text.strip().upper()
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
raise OutputParserException(
f"BooleanOutputParser expected output value to either be "
f"{self.true_val} or {self.false_val} (case-insensitive). "
f"Received {cleaned_text}."
)
return cleaned_text == self.true_val.upper()
@property
def _type(self) -> str:
return "boolean_output_parser"
""" # noqa: E501
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
is assumed to be the highest-likelihood Generation.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
return self.parse(result[0].text)
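So for the BooleanOutputParser example above, parse_result consults only the first (top-choice) candidate:

.. code-block:: python

    from langchain.schema import Generation

    parser = BooleanOutputParser()
    result = parser.parse_result([Generation(text="YES"), Generation(text="NO")])
    assert result is True  # the second candidate is ignored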
@abstractmethod
def parse(self, text: str) -> T:
"""Parse the output of an LLM call.
A method which takes in a string (assumed output of a language model )
and parses it into some structure.
"""Parse a single string model output into some structure.
Args:
text: output of language model
text: String output of language model.
Returns:
structured output
Structured output.
"""
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Optional method to parse the output of an LLM call with a prompt.
"""Parse the output of an LLM call with the input prompt for context.
The prompt is largely provided in the event the OutputParser wants
to retry or fix the output in some way, and needs information from
the prompt to do so.
Args:
completion: output of language model
prompt: prompt value
completion: String output of language model.
prompt: Input PromptValue.
Returns:
structured output
Structured output.
"""
return self.parse(completion)
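A sketch of the kind of parser this hook enables; the fallback logic below is hypothetical:

.. code-block:: python

    class LenientFallbackParser(BooleanOutputParser):
        """Hypothetical: relax matching if strict parsing fails."""

        def parse_with_prompt(self, completion: str, prompt: PromptValue) -> bool:
            try:
                return self.parse(completion)
            except OutputParserException:
                # A real retry parser could re-ask the model using `prompt`;
                # this sketch just scans the raw text instead.
                return self.true_val.upper() in completion.upper()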
@@ -575,7 +753,7 @@ class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
@property
def _type(self) -> str:
"""Return the type key."""
"""Return the output parser type for serialization."""
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
@@ -583,23 +761,26 @@ class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict = super().dict(**kwargs)
output_parser_dict["_type"] = self._type
return output_parser_dict
class NoOpOutputParser(BaseOutputParser[str]):
"""Output parser that just returns the text as is."""
"""'No operation' OutputParser that returns the text as is."""
@property
def lc_serializable(self) -> bool:
"""Whether the class LangChain serializable."""
return True
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""
return "default"
def parse(self, text: str) -> str:
"""Returns the input text with no changes."""
return text
@@ -610,13 +791,24 @@ class OutputParserException(ValueError):
that also may arise inside the output parser. OutputParserExceptions will be
available to catch and handle in ways to fix the parsing error, while other
errors will be raised.
Args:
error: The error that's being re-raised or an error message.
observation: String explanation of error which can be passed to a
model to try and remediate the issue.
llm_output: String model output that caused the error.
send_to_llm: Whether to send the observation and llm_output back to an Agent
after an OutputParserException has been raised. This gives the underlying
model driving the agent the context that the previous output was improperly
structured, in the hopes that it will update the output to the correct
format.
"""
def __init__(
self,
error: Any,
observation: str | None = None,
llm_output: str | None = None,
observation: Optional[str] = None,
llm_output: Optional[str] = None,
send_to_llm: bool = False,
):
super(OutputParserException, self).__init__(error)
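Raising the exception with the documented arguments (sketch; the strings are illustrative):

.. code-block:: python

    def parse_strict(text: str) -> bool:
        if text not in ("YES", "NO"):
            raise OutputParserException(
                error=f"Unexpected output: {text!r}",
                observation="Answer with exactly YES or NO.",
                llm_output=text,
                send_to_llm=True,  # let an agent show the model what went wrong
            )
        return text == "YES"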
@@ -632,16 +824,63 @@ class OutputParserException(ValueError):
class BaseDocumentTransformer(ABC):
"""Base interface for transforming documents."""
"""Abstract base class for document transformation systems.
A document transformation system takes a sequence of Documents and returns a
sequence of transformed Documents.
Example:
.. code-block:: python
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
embeddings: Embeddings
similarity_fn: Callable = cosine_similarity
similarity_threshold: float = 0.95
class Config:
arbitrary_types_allowed = True
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
stateful_documents = get_stateful_documents(documents)
embedded_documents = _get_embeddings_from_stateful_docs(
self.embeddings, stateful_documents
)
included_idxs = _filter_similar_embeddings(
embedded_documents, self.similarity_fn, self.similarity_threshold
)
return [stateful_documents[i] for i in sorted(included_idxs)]
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
""" # noqa: E501
@abstractmethod
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Transform a list of documents."""
"""Transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
@abstractmethod
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Asynchronously transform a list of documents."""
"""Asynchronously transform a list of documents.
Args:
documents: A sequence of Documents to be transformed.
Returns:
A list of transformed Documents.
"""
