mirror of https://github.com/hwchase17/langchain
Harrison/split schema dir (#7025)
should be no functional changes also keep __init__ exposing a lot for backwards compat --------- Co-authored-by: Dev 2049 <dev.dev2049@gmail.com> Co-authored-by: Bagatur <baskaryan@gmail.com>pull/7028/head
parent
556c425042
commit
3bfe7cf467
@ -1,886 +0,0 @@
|
||||
"""Common schema objects."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
Callbacks,
|
||||
)
|
||||
|
||||
RUN_KEY = "__run"
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||
) -> str:
|
||||
"""Convert sequence of Messages to strings and concatenate them into one string.
|
||||
|
||||
Args:
|
||||
messages: Messages to be converted to strings.
|
||||
human_prefix: The prefix to prepend to contents of HumanMessages.
|
||||
ai_prefix: THe prefix to prepend to contents of AIMessages.
|
||||
|
||||
Returns:
|
||||
A single string concatenation of all input messages.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="Hi, how are you?"),
|
||||
AIMessage(content="Good, how are you?"),
|
||||
]
|
||||
get_buffer_string(messages)
|
||||
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
|
||||
"""
|
||||
string_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
role = human_prefix
|
||||
elif isinstance(m, AIMessage):
|
||||
role = ai_prefix
|
||||
elif isinstance(m, SystemMessage):
|
||||
role = "System"
|
||||
elif isinstance(m, FunctionMessage):
|
||||
role = "Function"
|
||||
elif isinstance(m, ChatMessage):
|
||||
role = m.role
|
||||
else:
|
||||
raise ValueError(f"Got unsupported message type: {m}")
|
||||
message = f"{role}: {m.content}"
|
||||
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
|
||||
message += f"{m.additional_kwargs['function_call']}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""A full description of an action for an ActionAgent to execute."""
|
||||
|
||||
tool: str
|
||||
"""The name of the Tool to execute."""
|
||||
tool_input: Union[str, dict]
|
||||
"""The input to pass in to the Tool."""
|
||||
log: str
|
||||
"""Additional information to log about the action."""
|
||||
|
||||
|
||||
class AgentFinish(NamedTuple):
|
||||
"""The final return value of an ActionAgent."""
|
||||
|
||||
return_values: dict
|
||||
"""Dictionary of return values."""
|
||||
log: str
|
||||
"""Additional information to log about the return value"""
|
||||
|
||||
|
||||
class Generation(Serializable):
|
||||
"""A single text generation output."""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
"""Raw response from the provider. May include things like the
|
||||
reason for finishing or token log probabilities.
|
||||
"""
|
||||
# TODO: add log probs as separate attribute
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Whether this class is LangChain serializable."""
|
||||
return True
|
||||
|
||||
|
||||
class BaseMessage(Serializable):
|
||||
"""The base abstract Message class.
|
||||
|
||||
Messages are the inputs and outputs of ChatModels.
|
||||
"""
|
||||
|
||||
content: str
|
||||
"""The string contents of the message."""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Any additional information."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
"""Type of the Message, used for serialization."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Whether this class is LangChain serializable."""
|
||||
return True
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""A Message from a human."""
|
||||
|
||||
example: bool = False
|
||||
"""Whether this Message is being passed in to the model as part of an example
|
||||
conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "human"
|
||||
|
||||
|
||||
class AIMessage(BaseMessage):
|
||||
"""A Message from an AI."""
|
||||
|
||||
example: bool = False
|
||||
"""Whether this Message is being passed in to the model as part of an example
|
||||
conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "ai"
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""A Message for priming AI behavior, usually passed in as the first of a sequence
|
||||
of input messages.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "system"
|
||||
|
||||
|
||||
class FunctionMessage(BaseMessage):
|
||||
"""A Message for passing the result of executing a function back to a model."""
|
||||
|
||||
name: str
|
||||
"""The name of the function that was executed."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "function"
|
||||
|
||||
|
||||
class ChatMessage(BaseMessage):
|
||||
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
|
||||
|
||||
role: str
|
||||
"""The speaker / role of the Message."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "chat"
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
|
||||
|
||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||
"""Convert a sequence of Messages to a list of dictionaries.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages (as BaseMessages) to convert.
|
||||
|
||||
Returns:
|
||||
List of messages as dicts.
|
||||
"""
|
||||
return [_message_to_dict(m) for m in messages]
|
||||
|
||||
|
||||
def _message_from_dict(message: dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
elif _type == "ai":
|
||||
return AIMessage(**message["data"])
|
||||
elif _type == "system":
|
||||
return SystemMessage(**message["data"])
|
||||
elif _type == "chat":
|
||||
return ChatMessage(**message["data"])
|
||||
else:
|
||||
raise ValueError(f"Got unexpected type: {_type}")
|
||||
|
||||
|
||||
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||
"""Convert a sequence of messages from dicts to Message objects.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages (as dicts) to convert.
|
||||
|
||||
Returns:
|
||||
List of messages (BaseMessages).
|
||||
"""
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
"""A single chat generation output."""
|
||||
|
||||
text: str = ""
|
||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||
message: BaseMessage
|
||||
"""The message output by the chat model."""
|
||||
|
||||
@root_validator
|
||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Set the text attribute to be the contents of the message."""
|
||||
values["text"] = values["message"].content
|
||||
return values
|
||||
|
||||
|
||||
class RunInfo(BaseModel):
|
||||
"""Class that contains metadata for a single execution of a Chain or model."""
|
||||
|
||||
run_id: UUID
|
||||
"""A unique identifier for the model or chain run."""
|
||||
|
||||
|
||||
class ChatResult(BaseModel):
|
||||
"""Class that contains all results for a single chat model call."""
|
||||
|
||||
generations: List[ChatGeneration]
|
||||
"""List of the chat generations. This is a List because an input can have multiple
|
||||
candidate generations.
|
||||
"""
|
||||
llm_output: Optional[dict] = None
|
||||
"""For arbitrary LLM provider specific output."""
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
"""Class that contains all results for a batched LLM call."""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
"""List of generated outputs. This is a List[List[]] because
|
||||
each input could have multiple candidate generations."""
|
||||
llm_output: Optional[dict] = None
|
||||
"""Arbitrary LLM provider-specific output."""
|
||||
run: Optional[List[RunInfo]] = None
|
||||
"""List of metadata info for model call for each input."""
|
||||
|
||||
def flatten(self) -> List[LLMResult]:
|
||||
"""Flatten generations into a single list.
|
||||
|
||||
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
|
||||
contains only a single Generation. If token usage information is available,
|
||||
it is kept only for the LLMResult corresponding to the top-choice
|
||||
Generation, to avoid over-counting of token usage downstream.
|
||||
|
||||
Returns:
|
||||
List of LLMResults where each returned LLMResult contains a single
|
||||
Generation.
|
||||
"""
|
||||
llm_results = []
|
||||
for i, gen_list in enumerate(self.generations):
|
||||
# Avoid double counting tokens in OpenAICallback
|
||||
if i == 0:
|
||||
llm_results.append(
|
||||
LLMResult(
|
||||
generations=[gen_list],
|
||||
llm_output=self.llm_output,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self.llm_output is not None:
|
||||
llm_output = deepcopy(self.llm_output)
|
||||
llm_output["token_usage"] = dict()
|
||||
else:
|
||||
llm_output = None
|
||||
llm_results.append(
|
||||
LLMResult(
|
||||
generations=[gen_list],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
)
|
||||
return llm_results
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check for LLMResult equality by ignoring any metadata related to runs."""
|
||||
if not isinstance(other, LLMResult):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.generations == other.generations
|
||||
and self.llm_output == other.llm_output
|
||||
)
|
||||
|
||||
|
||||
class PromptValue(Serializable, ABC):
|
||||
"""Base abstract class for inputs to any language model.
|
||||
|
||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||
ChatModel inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
||||
|
||||
|
||||
class BaseMemory(Serializable, ABC):
|
||||
"""Base abstract class for memory in Chains.
|
||||
|
||||
Memory refers to state in Chains. Memory can be used to store information about
|
||||
past executions of a Chain and inject that information into the inputs of
|
||||
future executions of the Chain. For example, for conversational Chains Memory
|
||||
can be used to store conversations and automatically add them to future model
|
||||
prompts so that the model has the necessary context to respond coherently to
|
||||
the latest input.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
memories: Dict[str, Any] = dict()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return list(self.memories.keys())
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.memories
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
pass
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""The string keys this memory class will add to chain inputs."""
|
||||
|
||||
@abstractmethod
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
|
||||
|
||||
class BaseChatMessageHistory(ABC):
|
||||
"""Abstract base class for storing chat message history.
|
||||
|
||||
See `ChatMessageHistory` for default implementation.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class FileChatMessageHistory(BaseChatMessageHistory):
|
||||
storage_path: str
|
||||
session_id: str
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
|
||||
messages = json.loads(f.read())
|
||||
return messages_from_dict(messages)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
messages = self.messages.append(_message_to_dict(message))
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
json.dump(f, messages)
|
||||
|
||||
def clear(self):
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
f.write("[]")
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage]
|
||||
"""A list of Messages stored in-memory."""
|
||||
|
||||
def add_user_message(self, message: str) -> None:
|
||||
"""Convenience method for adding a human message string to the store.
|
||||
|
||||
Args:
|
||||
message: The string contents of a human message.
|
||||
"""
|
||||
self.add_message(HumanMessage(content=message))
|
||||
|
||||
def add_ai_message(self, message: str) -> None:
|
||||
"""Convenience method for adding an AI message string to the store.
|
||||
|
||||
Args:
|
||||
message: The string contents of an AI message.
|
||||
"""
|
||||
self.add_message(AIMessage(content=message))
|
||||
|
||||
# TODO: Make this an abstractmethod.
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a Message object to the store.
|
||||
|
||||
Args:
|
||||
message: A BaseMessage object to store.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Remove all messages from the store"""
|
||||
|
||||
|
||||
class Document(Serializable):
|
||||
"""Class for storing a piece of text and associated metadata."""
|
||||
|
||||
page_content: str
|
||||
"""String text."""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
"""Abstract base class for a Document retrieval system.
|
||||
|
||||
A retrieval system is defined as something that can take string queries and return
|
||||
the most 'relevant' Documents from some source.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
vectorizer: Any
|
||||
docs: List[Document]
|
||||
tfidf_array: Any
|
||||
k: int = 4
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
|
||||
query_vec = self.vectorizer.transform([query])
|
||||
# Op -- (n_docs,1) -- Cosine Sim with each doc
|
||||
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
|
||||
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
_expects_other_args: bool = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Version upgrade for old retrievers that implemented the public
|
||||
# methods directly.
|
||||
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
||||
" instead of `get_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
swap = cls.get_relevant_documents
|
||||
cls.get_relevant_documents = ( # type: ignore[assignment]
|
||||
BaseRetriever.get_relevant_documents
|
||||
)
|
||||
cls._get_relevant_documents = swap # type: ignore[assignment]
|
||||
if (
|
||||
hasattr(cls, "aget_relevant_documents")
|
||||
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
|
||||
):
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
||||
" instead of `aget_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
aswap = cls.aget_relevant_documents
|
||||
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
||||
BaseRetriever.aget_relevant_documents
|
||||
)
|
||||
cls._aget_relevant_documents = aswap # type: ignore[assignment]
|
||||
parameters = signature(cls._get_relevant_documents).parameters
|
||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
||||
# If a V1 retriever broke the interface and expects additional arguments
|
||||
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
|
||||
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for.
|
||||
run_manager: The callbacks handler to use.
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
def get_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for.
|
||||
callbacks: Callback manager or list of callbacks.
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
)
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
query,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = self._get_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = self._get_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = self._get_relevant_documents(query) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
)
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
query,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = await self._aget_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, # type: ignore[call-arg]
|
||||
)
|
||||
except Exception as e:
|
||||
await run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
Memory = BaseMemory
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
|
||||
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
|
||||
"""Class to parse the output of an LLM call.
|
||||
|
||||
Output parsers help structure language model responses.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class BooleanOutputParser(BaseOutputParser[bool]):
|
||||
true_val: str = "YES"
|
||||
false_val: str = "NO"
|
||||
|
||||
def parse(self, text: str) -> bool:
|
||||
cleaned_text = text.strip().upper()
|
||||
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
|
||||
raise OutputParserException(
|
||||
f"BooleanOutputParser expected output value to either be "
|
||||
f"{self.true_val} or {self.false_val} (case-insensitive). "
|
||||
f"Received {cleaned_text}."
|
||||
)
|
||||
return cleaned_text == self.true_val.upper()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "boolean_output_parser"
|
||||
""" # noqa: E501
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
is assumed to be the highest-likelihood Generation.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return self.parse(result[0].text)
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> T:
|
||||
"""Parse a single string model output into some structure.
|
||||
|
||||
Args:
|
||||
text: String output of language model.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
# TODO: rename 'completion' -> 'text'.
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
"""Parse the output of an LLM call with the input prompt for context.
|
||||
|
||||
The prompt is largely provided in the event the OutputParser wants
|
||||
to retry or fix the output in some way, and needs information from
|
||||
the prompt to do so.
|
||||
|
||||
Args:
|
||||
completion: String output of language model.
|
||||
prompt: Input PromptValue.
|
||||
|
||||
Returns:
|
||||
Structured output
|
||||
"""
|
||||
return self.parse(completion)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Instructions on how the LLM output should be formatted."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the output parser type for serialization."""
|
||||
raise NotImplementedError(
|
||||
f"_type property is not implemented in class {self.__class__.__name__}."
|
||||
" This is required for serialization."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
||||
|
||||
class NoOpOutputParser(BaseOutputParser[str]):
|
||||
"""'No operation' OutputParser that returns the text as is."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Whether the class LangChain serializable."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the output parser type for serialization."""
|
||||
return "default"
|
||||
|
||||
def parse(self, text: str) -> str:
|
||||
"""Returns the input text with no changes."""
|
||||
return text
|
||||
|
||||
|
||||
class OutputParserException(ValueError):
|
||||
"""Exception that output parsers should raise to signify a parsing error.
|
||||
|
||||
This exists to differentiate parsing errors from other code or execution errors
|
||||
that also may arise inside the output parser. OutputParserExceptions will be
|
||||
available to catch and handle in ways to fix the parsing error, while other
|
||||
errors will be raised.
|
||||
|
||||
Args:
|
||||
error: The error that's being re-raised or an error message.
|
||||
observation: String explanation of error which can be passed to a
|
||||
model to try and remediate the issue.
|
||||
llm_output: String model output which is error-ing.
|
||||
send_to_llm: Whether to send the observation and llm_output back to an Agent
|
||||
after an OutputParserException has been raised. This gives the underlying
|
||||
model driving the agent the context that the previous output was improperly
|
||||
structured, in the hopes that it will update the output to the correct
|
||||
format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error: Any,
|
||||
observation: Optional[str] = None,
|
||||
llm_output: Optional[str] = None,
|
||||
send_to_llm: bool = False,
|
||||
):
|
||||
super(OutputParserException, self).__init__(error)
|
||||
if send_to_llm:
|
||||
if observation is None or llm_output is None:
|
||||
raise ValueError(
|
||||
"Arguments 'observation' & 'llm_output'"
|
||||
" are required if 'send_to_llm' is True"
|
||||
)
|
||||
self.observation = observation
|
||||
self.llm_output = llm_output
|
||||
self.send_to_llm = send_to_llm
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC):
|
||||
"""Abstract base class for document transformation systems.
|
||||
|
||||
A document transformation system takes a sequence of Documents and returns a
|
||||
sequence of transformed Documents.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
||||
embeddings: Embeddings
|
||||
similarity_fn: Callable = cosine_similarity
|
||||
similarity_threshold: float = 0.95
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
stateful_documents = get_stateful_documents(documents)
|
||||
embedded_documents = _get_embeddings_from_stateful_docs(
|
||||
self.embeddings, stateful_documents
|
||||
)
|
||||
included_idxs = _filter_similar_embeddings(
|
||||
embedded_documents, self.similarity_fn, self.similarity_threshold
|
||||
)
|
||||
return [stateful_documents[i] for i in sorted(included_idxs)]
|
||||
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
@abstractmethod
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Transform a list of documents.
|
||||
|
||||
Args:
|
||||
documents: A sequence of Documents to be transformed.
|
||||
|
||||
Returns:
|
||||
A list of transformed Documents.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Asynchronously transform a list of documents.
|
||||
|
||||
Args:
|
||||
documents: A sequence of Documents to be transformed.
|
||||
|
||||
Returns:
|
||||
A list of transformed Documents.
|
||||
"""
|
@ -0,0 +1,67 @@
|
||||
from langchain.schema.agent import AgentAction, AgentFinish
|
||||
from langchain.schema.document import BaseDocumentTransformer, Document
|
||||
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
_message_from_dict,
|
||||
_message_to_dict,
|
||||
get_buffer_string,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain.schema.output import (
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
Generation,
|
||||
LLMResult,
|
||||
RunInfo,
|
||||
)
|
||||
from langchain.schema.output_parser import (
|
||||
BaseLLMOutputParser,
|
||||
BaseOutputParser,
|
||||
NoOpOutputParser,
|
||||
OutputParserException,
|
||||
)
|
||||
from langchain.schema.prompt import PromptValue
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
|
||||
RUN_KEY = "__run"
|
||||
Memory = BaseMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"BaseChatMessageHistory",
|
||||
"AgentFinish",
|
||||
"AgentAction",
|
||||
"Document",
|
||||
"BaseDocumentTransformer",
|
||||
"BaseMessage",
|
||||
"ChatMessage",
|
||||
"FunctionMessage",
|
||||
"HumanMessage",
|
||||
"AIMessage",
|
||||
"SystemMessage",
|
||||
"messages_from_dict",
|
||||
"messages_to_dict",
|
||||
"_message_to_dict",
|
||||
"_message_from_dict",
|
||||
"get_buffer_string",
|
||||
"RunInfo",
|
||||
"LLMResult",
|
||||
"ChatResult",
|
||||
"ChatGeneration",
|
||||
"Generation",
|
||||
"PromptValue",
|
||||
"BaseRetriever",
|
||||
"RUN_KEY",
|
||||
"Memory",
|
||||
"OutputParserException",
|
||||
"NoOpOutputParser",
|
||||
"BaseOutputParser",
|
||||
"BaseLLMOutputParser",
|
||||
]
|
@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Union
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""A full description of an action for an ActionAgent to execute."""
|
||||
|
||||
tool: str
|
||||
"""The name of the Tool to execute."""
|
||||
tool_input: Union[str, dict]
|
||||
"""The input to pass in to the Tool."""
|
||||
log: str
|
||||
"""Additional information to log about the action."""
|
||||
|
||||
|
||||
class AgentFinish(NamedTuple):
|
||||
"""The final return value of an ActionAgent."""
|
||||
|
||||
return_values: dict
|
||||
"""Dictionary of return values."""
|
||||
log: str
|
||||
"""Additional information to log about the return value"""
|
@ -0,0 +1,82 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
|
||||
|
||||
class Document(Serializable):
|
||||
"""Class for storing a piece of text and associated metadata."""
|
||||
|
||||
page_content: str
|
||||
"""String text."""
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||
documents, etc.).
|
||||
"""
|
||||
|
||||
|
||||
class BaseDocumentTransformer(ABC):
|
||||
"""Abstract base class for document transformation systems.
|
||||
|
||||
A document transformation system takes a sequence of Documents and returns a
|
||||
sequence of transformed Documents.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
||||
embeddings: Embeddings
|
||||
similarity_fn: Callable = cosine_similarity
|
||||
similarity_threshold: float = 0.95
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
stateful_documents = get_stateful_documents(documents)
|
||||
embedded_documents = _get_embeddings_from_stateful_docs(
|
||||
self.embeddings, stateful_documents
|
||||
)
|
||||
included_idxs = _filter_similar_embeddings(
|
||||
embedded_documents, self.similarity_fn, self.similarity_threshold
|
||||
)
|
||||
return [stateful_documents[i] for i in sorted(included_idxs)]
|
||||
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
@abstractmethod
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Transform a list of documents.
|
||||
|
||||
Args:
|
||||
documents: A sequence of Documents to be transformed.
|
||||
|
||||
Returns:
|
||||
A list of transformed Documents.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Asynchronously transform a list of documents.
|
||||
|
||||
Args:
|
||||
documents: A sequence of Documents to be transformed.
|
||||
|
||||
Returns:
|
||||
A list of transformed Documents.
|
||||
"""
|
@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||
|
||||
|
||||
class BaseMemory(Serializable, ABC):
|
||||
"""Base abstract class for memory in Chains.
|
||||
|
||||
Memory refers to state in Chains. Memory can be used to store information about
|
||||
past executions of a Chain and inject that information into the inputs of
|
||||
future executions of the Chain. For example, for conversational Chains Memory
|
||||
can be used to store conversations and automatically add them to future model
|
||||
prompts so that the model has the necessary context to respond coherently to
|
||||
the latest input.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
memories: Dict[str, Any] = dict()
|
||||
|
||||
@property
|
||||
def memory_variables(self) -> List[str]:
|
||||
return list(self.memories.keys())
|
||||
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.memories
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
pass
|
||||
""" # noqa: E501
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def memory_variables(self) -> List[str]:
|
||||
"""The string keys this memory class will add to chain inputs."""
|
||||
|
||||
@abstractmethod
|
||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
|
||||
@abstractmethod
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||
"""Save the context of this chain run to memory."""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Clear memory contents."""
|
||||
|
||||
|
||||
class BaseChatMessageHistory(ABC):
|
||||
"""Abstract base class for storing chat message history.
|
||||
|
||||
See `ChatMessageHistory` for default implementation.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class FileChatMessageHistory(BaseChatMessageHistory):
|
||||
storage_path: str
|
||||
session_id: str
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
|
||||
messages = json.loads(f.read())
|
||||
return messages_from_dict(messages)
|
||||
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
messages = self.messages.append(_message_to_dict(message))
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
json.dump(f, messages)
|
||||
|
||||
def clear(self):
|
||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||
f.write("[]")
|
||||
"""
|
||||
|
||||
messages: List[BaseMessage]
|
||||
"""A list of Messages stored in-memory."""
|
||||
|
||||
def add_user_message(self, message: str) -> None:
|
||||
"""Convenience method for adding a human message string to the store.
|
||||
|
||||
Args:
|
||||
message: The string contents of a human message.
|
||||
"""
|
||||
self.add_message(HumanMessage(content=message))
|
||||
|
||||
def add_ai_message(self, message: str) -> None:
|
||||
"""Convenience method for adding an AI message string to the store.
|
||||
|
||||
Args:
|
||||
message: The string contents of an AI message.
|
||||
"""
|
||||
self.add_message(AIMessage(content=message))
|
||||
|
||||
# TODO: Make this an abstractmethod.
|
||||
def add_message(self, message: BaseMessage) -> None:
|
||||
"""Add a Message object to the store.
|
||||
|
||||
Args:
|
||||
message: A BaseMessage object to store.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Remove all messages from the store"""
|
@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import List, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||
) -> str:
|
||||
"""Convert sequence of Messages to strings and concatenate them into one string.
|
||||
|
||||
Args:
|
||||
messages: Messages to be converted to strings.
|
||||
human_prefix: The prefix to prepend to contents of HumanMessages.
|
||||
ai_prefix: THe prefix to prepend to contents of AIMessages.
|
||||
|
||||
Returns:
|
||||
A single string concatenation of all input messages.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.schema import AIMessage, HumanMessage
|
||||
|
||||
messages = [
|
||||
HumanMessage(content="Hi, how are you?"),
|
||||
AIMessage(content="Good, how are you?"),
|
||||
]
|
||||
get_buffer_string(messages)
|
||||
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
|
||||
"""
|
||||
string_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
role = human_prefix
|
||||
elif isinstance(m, AIMessage):
|
||||
role = ai_prefix
|
||||
elif isinstance(m, SystemMessage):
|
||||
role = "System"
|
||||
elif isinstance(m, FunctionMessage):
|
||||
role = "Function"
|
||||
elif isinstance(m, ChatMessage):
|
||||
role = m.role
|
||||
else:
|
||||
raise ValueError(f"Got unsupported message type: {m}")
|
||||
message = f"{role}: {m.content}"
|
||||
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
|
||||
message += f"{m.additional_kwargs['function_call']}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
class BaseMessage(Serializable):
|
||||
"""The base abstract Message class.
|
||||
|
||||
Messages are the inputs and outputs of ChatModels.
|
||||
"""
|
||||
|
||||
content: str
|
||||
"""The string contents of the message."""
|
||||
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
"""Any additional information."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
"""Type of the Message, used for serialization."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Whether this class is LangChain serializable."""
|
||||
return True
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""A Message from a human."""
|
||||
|
||||
example: bool = False
|
||||
"""Whether this Message is being passed in to the model as part of an example
|
||||
conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "human"
|
||||
|
||||
|
||||
class AIMessage(BaseMessage):
|
||||
"""A Message from an AI."""
|
||||
|
||||
example: bool = False
|
||||
"""Whether this Message is being passed in to the model as part of an example
|
||||
conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "ai"
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""A Message for priming AI behavior, usually passed in as the first of a sequence
|
||||
of input messages.
|
||||
"""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "system"
|
||||
|
||||
|
||||
class FunctionMessage(BaseMessage):
|
||||
"""A Message for passing the result of executing a function back to a model."""
|
||||
|
||||
name: str
|
||||
"""The name of the function that was executed."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "function"
|
||||
|
||||
|
||||
class ChatMessage(BaseMessage):
|
||||
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
|
||||
|
||||
role: str
|
||||
"""The speaker / role of the Message."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "chat"
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
|
||||
|
||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||
"""Convert a sequence of Messages to a list of dictionaries.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages (as BaseMessages) to convert.
|
||||
|
||||
Returns:
|
||||
List of messages as dicts.
|
||||
"""
|
||||
return [_message_to_dict(m) for m in messages]
|
||||
|
||||
|
||||
def _message_from_dict(message: dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
elif _type == "ai":
|
||||
return AIMessage(**message["data"])
|
||||
elif _type == "system":
|
||||
return SystemMessage(**message["data"])
|
||||
elif _type == "chat":
|
||||
return ChatMessage(**message["data"])
|
||||
else:
|
||||
raise ValueError(f"Got unexpected type: {_type}")
|
||||
|
||||
|
||||
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||
"""Convert a sequence of messages from dicts to Message objects.
|
||||
|
||||
Args:
|
||||
messages: Sequence of messages (as dicts) to convert.
|
||||
|
||||
Returns:
|
||||
List of messages (BaseMessages).
|
||||
"""
|
||||
return [_message_from_dict(m) for m in messages]
|
@ -0,0 +1,118 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
|
||||
class Generation(Serializable):
|
||||
"""A single text generation output."""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
"""Raw response from the provider. May include things like the
|
||||
reason for finishing or token log probabilities.
|
||||
"""
|
||||
# TODO: add log probs as separate attribute
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Whether this class is LangChain serializable."""
|
||||
return True
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
"""A single chat generation output."""
|
||||
|
||||
text: str = ""
|
||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||
message: BaseMessage
|
||||
"""The message output by the chat model."""
|
||||
|
||||
@root_validator
|
||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Set the text attribute to be the contents of the message."""
|
||||
values["text"] = values["message"].content
|
||||
return values
|
||||
|
||||
|
||||
class RunInfo(BaseModel):
|
||||
"""Class that contains metadata for a single execution of a Chain or model."""
|
||||
|
||||
run_id: UUID
|
||||
"""A unique identifier for the model or chain run."""
|
||||
|
||||
|
||||
class ChatResult(BaseModel):
|
||||
"""Class that contains all results for a single chat model call."""
|
||||
|
||||
generations: List[ChatGeneration]
|
||||
"""List of the chat generations. This is a List because an input can have multiple
|
||||
candidate generations.
|
||||
"""
|
||||
llm_output: Optional[dict] = None
|
||||
"""For arbitrary LLM provider specific output."""
|
||||
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
"""Class that contains all results for a batched LLM call."""
|
||||
|
||||
generations: List[List[Generation]]
|
||||
"""List of generated outputs. This is a List[List[]] because
|
||||
each input could have multiple candidate generations."""
|
||||
llm_output: Optional[dict] = None
|
||||
"""Arbitrary LLM provider-specific output."""
|
||||
run: Optional[List[RunInfo]] = None
|
||||
"""List of metadata info for model call for each input."""
|
||||
|
||||
def flatten(self) -> List[LLMResult]:
|
||||
"""Flatten generations into a single list.
|
||||
|
||||
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
|
||||
contains only a single Generation. If token usage information is available,
|
||||
it is kept only for the LLMResult corresponding to the top-choice
|
||||
Generation, to avoid over-counting of token usage downstream.
|
||||
|
||||
Returns:
|
||||
List of LLMResults where each returned LLMResult contains a single
|
||||
Generation.
|
||||
"""
|
||||
llm_results = []
|
||||
for i, gen_list in enumerate(self.generations):
|
||||
# Avoid double counting tokens in OpenAICallback
|
||||
if i == 0:
|
||||
llm_results.append(
|
||||
LLMResult(
|
||||
generations=[gen_list],
|
||||
llm_output=self.llm_output,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if self.llm_output is not None:
|
||||
llm_output = deepcopy(self.llm_output)
|
||||
llm_output["token_usage"] = dict()
|
||||
else:
|
||||
llm_output = None
|
||||
llm_results.append(
|
||||
LLMResult(
|
||||
generations=[gen_list],
|
||||
llm_output=llm_output,
|
||||
)
|
||||
)
|
||||
return llm_results
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check for LLMResult equality by ignoring any metadata related to runs."""
|
||||
if not isinstance(other, LLMResult):
|
||||
return NotImplemented
|
||||
return (
|
||||
self.generations == other.generations
|
||||
and self.llm_output == other.llm_output
|
||||
)
|
@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, TypeVar
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.output import Generation
|
||||
from langchain.schema.prompt import PromptValue
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
|
||||
"""Abstract base class for parsing the outputs of a model."""
|
||||
|
||||
@abstractmethod
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
|
||||
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
|
||||
"""Class to parse the output of an LLM call.
|
||||
|
||||
Output parsers help structure language model responses.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class BooleanOutputParser(BaseOutputParser[bool]):
|
||||
true_val: str = "YES"
|
||||
false_val: str = "NO"
|
||||
|
||||
def parse(self, text: str) -> bool:
|
||||
cleaned_text = text.strip().upper()
|
||||
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
|
||||
raise OutputParserException(
|
||||
f"BooleanOutputParser expected output value to either be "
|
||||
f"{self.true_val} or {self.false_val} (case-insensitive). "
|
||||
f"Received {cleaned_text}."
|
||||
)
|
||||
return cleaned_text == self.true_val.upper()
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "boolean_output_parser"
|
||||
""" # noqa: E501
|
||||
|
||||
def parse_result(self, result: List[Generation]) -> T:
|
||||
"""Parse a list of candidate model Generations into a specific format.
|
||||
|
||||
The return value is parsed from only the first Generation in the result, which
|
||||
is assumed to be the highest-likelihood Generation.
|
||||
|
||||
Args:
|
||||
result: A list of Generations to be parsed. The Generations are assumed
|
||||
to be different candidate outputs for a single model input.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
return self.parse(result[0].text)
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> T:
|
||||
"""Parse a single string model output into some structure.
|
||||
|
||||
Args:
|
||||
text: String output of language model.
|
||||
|
||||
Returns:
|
||||
Structured output.
|
||||
"""
|
||||
|
||||
# TODO: rename 'completion' -> 'text'.
|
||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||
"""Parse the output of an LLM call with the input prompt for context.
|
||||
|
||||
The prompt is largely provided in the event the OutputParser wants
|
||||
to retry or fix the output in some way, and needs information from
|
||||
the prompt to do so.
|
||||
|
||||
Args:
|
||||
completion: String output of language model.
|
||||
prompt: Input PromptValue.
|
||||
|
||||
Returns:
|
||||
Structured output
|
||||
"""
|
||||
return self.parse(completion)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
"""Instructions on how the LLM output should be formatted."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the output parser type for serialization."""
|
||||
raise NotImplementedError(
|
||||
f"_type property is not implemented in class {self.__class__.__name__}."
|
||||
" This is required for serialization."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict(**kwargs)
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
||||
|
||||
class NoOpOutputParser(BaseOutputParser[str]):
|
||||
"""'No operation' OutputParser that returns the text as is."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Whether the class LangChain serializable."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""Return the output parser type for serialization."""
|
||||
return "default"
|
||||
|
||||
def parse(self, text: str) -> str:
|
||||
"""Returns the input text with no changes."""
|
||||
return text
|
||||
|
||||
|
||||
class OutputParserException(ValueError):
|
||||
"""Exception that output parsers should raise to signify a parsing error.
|
||||
|
||||
This exists to differentiate parsing errors from other code or execution errors
|
||||
that also may arise inside the output parser. OutputParserExceptions will be
|
||||
available to catch and handle in ways to fix the parsing error, while other
|
||||
errors will be raised.
|
||||
|
||||
Args:
|
||||
error: The error that's being re-raised or an error message.
|
||||
observation: String explanation of error which can be passed to a
|
||||
model to try and remediate the issue.
|
||||
llm_output: String model output which is error-ing.
|
||||
send_to_llm: Whether to send the observation and llm_output back to an Agent
|
||||
after an OutputParserException has been raised. This gives the underlying
|
||||
model driving the agent the context that the previous output was improperly
|
||||
structured, in the hopes that it will update the output to the correct
|
||||
format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error: Any,
|
||||
observation: Optional[str] = None,
|
||||
llm_output: Optional[str] = None,
|
||||
send_to_llm: bool = False,
|
||||
):
|
||||
super(OutputParserException, self).__init__(error)
|
||||
if send_to_llm:
|
||||
if observation is None or llm_output is None:
|
||||
raise ValueError(
|
||||
"Arguments 'observation' & 'llm_output'"
|
||||
" are required if 'send_to_llm' is True"
|
||||
)
|
||||
self.observation = observation
|
||||
self.llm_output = llm_output
|
||||
self.send_to_llm = send_to_llm
|
@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
|
||||
class PromptValue(Serializable, ABC):
|
||||
"""Base abstract class for inputs to any language model.
|
||||
|
||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
||||
ChatModel inputs.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt value as string."""
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as a list of Messages."""
|
@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
from langchain.schema.document import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
Callbacks,
|
||||
)
|
||||
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
"""Abstract base class for a Document retrieval system.
|
||||
|
||||
A retrieval system is defined as something that can take string queries and return
|
||||
the most 'relevant' Documents from some source.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
class TFIDFRetriever(BaseRetriever, BaseModel):
|
||||
vectorizer: Any
|
||||
docs: List[Document]
|
||||
tfidf_array: Any
|
||||
k: int = 4
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
|
||||
query_vec = self.vectorizer.transform([query])
|
||||
# Op -- (n_docs,1) -- Cosine Sim with each doc
|
||||
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
|
||||
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
_expects_other_args: bool = False
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Version upgrade for old retrievers that implemented the public
|
||||
# methods directly.
|
||||
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
||||
" instead of `get_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
swap = cls.get_relevant_documents
|
||||
cls.get_relevant_documents = ( # type: ignore[assignment]
|
||||
BaseRetriever.get_relevant_documents
|
||||
)
|
||||
cls._get_relevant_documents = swap # type: ignore[assignment]
|
||||
if (
|
||||
hasattr(cls, "aget_relevant_documents")
|
||||
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
|
||||
):
|
||||
warnings.warn(
|
||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
||||
" instead of `aget_relevant_documents`",
|
||||
DeprecationWarning,
|
||||
)
|
||||
aswap = cls.aget_relevant_documents
|
||||
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
||||
BaseRetriever.aget_relevant_documents
|
||||
)
|
||||
cls._aget_relevant_documents = aswap # type: ignore[assignment]
|
||||
parameters = signature(cls._get_relevant_documents).parameters
|
||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
||||
# If a V1 retriever broke the interface and expects additional arguments
|
||||
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
|
||||
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for.
|
||||
run_manager: The callbacks handler to use.
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
def get_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: String to find relevant documents for.
|
||||
callbacks: Callback manager or list of callbacks.
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
)
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
query,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = self._get_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = self._get_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = self._get_relevant_documents(query) # type: ignore[call-arg]
|
||||
except Exception as e:
|
||||
run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
)
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
query,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
if self._new_arg_supported:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, run_manager=run_manager, **kwargs
|
||||
)
|
||||
elif self._expects_other_args:
|
||||
result = await self._aget_relevant_documents(query, **kwargs)
|
||||
else:
|
||||
result = await self._aget_relevant_documents(
|
||||
query, # type: ignore[call-arg]
|
||||
)
|
||||
except Exception as e:
|
||||
await run_manager.on_retriever_error(e)
|
||||
raise e
|
||||
else:
|
||||
await run_manager.on_retriever_end(
|
||||
result,
|
||||
**kwargs,
|
||||
)
|
||||
return result
|
Loading…
Reference in New Issue