mirror of https://github.com/hwchase17/langchain
Harrison/split schema dir (#7025)
Should be no functional changes. Also keep __init__ exposing a lot for backwards compat.
---------
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
parent 556c425042
commit 3bfe7cf467
@@ -1,886 +0,0 @@
"""Common schema objects."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from copy import deepcopy
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from inspect import signature
|
|
||||||
from typing import (
|
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Dict,
|
|
||||||
Generic,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, root_validator
|
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from langchain.callbacks.manager import (
|
|
||||||
AsyncCallbackManagerForRetrieverRun,
|
|
||||||
CallbackManagerForRetrieverRun,
|
|
||||||
Callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
RUN_KEY = "__run"
|
|
||||||
|
|
||||||
|
|
||||||
def get_buffer_string(
|
|
||||||
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
|
||||||
) -> str:
|
|
||||||
"""Convert sequence of Messages to strings and concatenate them into one string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Messages to be converted to strings.
|
|
||||||
human_prefix: The prefix to prepend to contents of HumanMessages.
|
|
||||||
ai_prefix: The prefix to prepend to contents of AIMessages.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A single string concatenation of all input messages.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain.schema import AIMessage, HumanMessage
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
HumanMessage(content="Hi, how are you?"),
|
|
||||||
AIMessage(content="Good, how are you?"),
|
|
||||||
]
|
|
||||||
get_buffer_string(messages)
|
|
||||||
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
|
|
||||||
"""
|
|
||||||
string_messages = []
|
|
||||||
for m in messages:
|
|
||||||
if isinstance(m, HumanMessage):
|
|
||||||
role = human_prefix
|
|
||||||
elif isinstance(m, AIMessage):
|
|
||||||
role = ai_prefix
|
|
||||||
elif isinstance(m, SystemMessage):
|
|
||||||
role = "System"
|
|
||||||
elif isinstance(m, FunctionMessage):
|
|
||||||
role = "Function"
|
|
||||||
elif isinstance(m, ChatMessage):
|
|
||||||
role = m.role
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unsupported message type: {m}")
|
|
||||||
message = f"{role}: {m.content}"
|
|
||||||
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
|
|
||||||
message += f"{m.additional_kwargs['function_call']}"
|
|
||||||
string_messages.append(message)
|
|
||||||
|
|
||||||
return "\n".join(string_messages)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentAction:
|
|
||||||
"""A full description of an action for an ActionAgent to execute."""
|
|
||||||
|
|
||||||
tool: str
|
|
||||||
"""The name of the Tool to execute."""
|
|
||||||
tool_input: Union[str, dict]
|
|
||||||
"""The input to pass in to the Tool."""
|
|
||||||
log: str
|
|
||||||
"""Additional information to log about the action."""
|
|
||||||
|
|
||||||
|
|
||||||
class AgentFinish(NamedTuple):
|
|
||||||
"""The final return value of an ActionAgent."""
|
|
||||||
|
|
||||||
return_values: dict
|
|
||||||
"""Dictionary of return values."""
|
|
||||||
log: str
|
|
||||||
"""Additional information to log about the return value"""
|
|
||||||
|
|
||||||
|
|
||||||
class Generation(Serializable):
|
|
||||||
"""A single text generation output."""
|
|
||||||
|
|
||||||
text: str
|
|
||||||
"""Generated text output."""
|
|
||||||
|
|
||||||
generation_info: Optional[Dict[str, Any]] = None
|
|
||||||
"""Raw response from the provider. May include things like the
|
|
||||||
reason for finishing or token log probabilities.
|
|
||||||
"""
|
|
||||||
# TODO: add log probs as separate attribute
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_serializable(self) -> bool:
|
|
||||||
"""Whether this class is LangChain serializable."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMessage(Serializable):
|
|
||||||
"""The base abstract Message class.
|
|
||||||
|
|
||||||
Messages are the inputs and outputs of ChatModels.
|
|
||||||
"""
|
|
||||||
|
|
||||||
content: str
|
|
||||||
"""The string contents of the message."""
|
|
||||||
|
|
||||||
additional_kwargs: dict = Field(default_factory=dict)
|
|
||||||
"""Any additional information."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the Message, used for serialization."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_serializable(self) -> bool:
|
|
||||||
"""Whether this class is LangChain serializable."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class HumanMessage(BaseMessage):
|
|
||||||
"""A Message from a human."""
|
|
||||||
|
|
||||||
example: bool = False
|
|
||||||
"""Whether this Message is being passed in to the model as part of an example
|
|
||||||
conversation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "human"
|
|
||||||
|
|
||||||
|
|
||||||
class AIMessage(BaseMessage):
|
|
||||||
"""A Message from an AI."""
|
|
||||||
|
|
||||||
example: bool = False
|
|
||||||
"""Whether this Message is being passed in to the model as part of an example
|
|
||||||
conversation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "ai"
|
|
||||||
|
|
||||||
|
|
||||||
class SystemMessage(BaseMessage):
|
|
||||||
"""A Message for priming AI behavior, usually passed in as the first of a sequence
|
|
||||||
of input messages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "system"
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionMessage(BaseMessage):
|
|
||||||
"""A Message for passing the result of executing a function back to a model."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""The name of the function that was executed."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "function"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseMessage):
|
|
||||||
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
|
|
||||||
|
|
||||||
role: str
|
|
||||||
"""The speaker / role of the Message."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "chat"
|
|
||||||
|
|
||||||
|
|
||||||
def _message_to_dict(message: BaseMessage) -> dict:
|
|
||||||
return {"type": message.type, "data": message.dict()}
|
|
||||||
|
|
||||||
|
|
||||||
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
|
||||||
"""Convert a sequence of Messages to a list of dictionaries.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Sequence of messages (as BaseMessages) to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of messages as dicts.
|
|
||||||
"""
|
|
||||||
return [_message_to_dict(m) for m in messages]
|
|
||||||
|
|
||||||
|
|
||||||
def _message_from_dict(message: dict) -> BaseMessage:
|
|
||||||
_type = message["type"]
|
|
||||||
if _type == "human":
|
|
||||||
return HumanMessage(**message["data"])
|
|
||||||
elif _type == "ai":
|
|
||||||
return AIMessage(**message["data"])
|
|
||||||
elif _type == "system":
|
|
||||||
return SystemMessage(**message["data"])
|
|
||||||
elif _type == "chat":
|
|
||||||
return ChatMessage(**message["data"])
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unexpected type: {_type}")
|
|
||||||
|
|
||||||
|
|
||||||
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
|
||||||
"""Convert a sequence of messages from dicts to Message objects.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Sequence of messages (as dicts) to convert.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of messages (BaseMessages).
|
|
||||||
"""
|
|
||||||
return [_message_from_dict(m) for m in messages]
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGeneration(Generation):
|
|
||||||
"""A single chat generation output."""
|
|
||||||
|
|
||||||
text: str = ""
|
|
||||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
|
||||||
message: BaseMessage
|
|
||||||
"""The message output by the chat model."""
|
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Set the text attribute to be the contents of the message."""
|
|
||||||
values["text"] = values["message"].content
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
class RunInfo(BaseModel):
|
|
||||||
"""Class that contains metadata for a single execution of a Chain or model."""
|
|
||||||
|
|
||||||
run_id: UUID
|
|
||||||
"""A unique identifier for the model or chain run."""
|
|
||||||
|
|
||||||
|
|
||||||
class ChatResult(BaseModel):
|
|
||||||
"""Class that contains all results for a single chat model call."""
|
|
||||||
|
|
||||||
generations: List[ChatGeneration]
|
|
||||||
"""List of the chat generations. This is a List because an input can have multiple
|
|
||||||
candidate generations.
|
|
||||||
"""
|
|
||||||
llm_output: Optional[dict] = None
|
|
||||||
"""For arbitrary LLM provider specific output."""
|
|
||||||
|
|
||||||
|
|
||||||
class LLMResult(BaseModel):
|
|
||||||
"""Class that contains all results for a batched LLM call."""
|
|
||||||
|
|
||||||
generations: List[List[Generation]]
|
|
||||||
"""List of generated outputs. This is a List[List[]] because
|
|
||||||
each input could have multiple candidate generations."""
|
|
||||||
llm_output: Optional[dict] = None
|
|
||||||
"""Arbitrary LLM provider-specific output."""
|
|
||||||
run: Optional[List[RunInfo]] = None
|
|
||||||
"""List of metadata info for model call for each input."""
|
|
||||||
|
|
||||||
def flatten(self) -> List[LLMResult]:
|
|
||||||
"""Flatten generations into a single list.
|
|
||||||
|
|
||||||
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
|
|
||||||
contains only a single Generation. If token usage information is available,
|
|
||||||
it is kept only for the LLMResult corresponding to the top-choice
|
|
||||||
Generation, to avoid over-counting of token usage downstream.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of LLMResults where each returned LLMResult contains a single
|
|
||||||
Generation.
|
|
||||||
"""
|
|
||||||
llm_results = []
|
|
||||||
for i, gen_list in enumerate(self.generations):
|
|
||||||
# Avoid double counting tokens in OpenAICallback
|
|
||||||
if i == 0:
|
|
||||||
llm_results.append(
|
|
||||||
LLMResult(
|
|
||||||
generations=[gen_list],
|
|
||||||
llm_output=self.llm_output,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if self.llm_output is not None:
|
|
||||||
llm_output = deepcopy(self.llm_output)
|
|
||||||
llm_output["token_usage"] = dict()
|
|
||||||
else:
|
|
||||||
llm_output = None
|
|
||||||
llm_results.append(
|
|
||||||
LLMResult(
|
|
||||||
generations=[gen_list],
|
|
||||||
llm_output=llm_output,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return llm_results
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
"""Check for LLMResult equality by ignoring any metadata related to runs."""
|
|
||||||
if not isinstance(other, LLMResult):
|
|
||||||
return NotImplemented
|
|
||||||
return (
|
|
||||||
self.generations == other.generations
|
|
||||||
and self.llm_output == other.llm_output
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PromptValue(Serializable, ABC):
|
|
||||||
"""Base abstract class for inputs to any language model.
|
|
||||||
|
|
||||||
PromptValues can be converted to both LLM (pure text-generation) inputs and
|
|
||||||
ChatModel inputs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def to_string(self) -> str:
|
|
||||||
"""Return prompt value as string."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def to_messages(self) -> List[BaseMessage]:
|
|
||||||
"""Return prompt as a list of Messages."""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMemory(Serializable, ABC):
|
|
||||||
"""Base abstract class for memory in Chains.
|
|
||||||
|
|
||||||
Memory refers to state in Chains. Memory can be used to store information about
|
|
||||||
past executions of a Chain and inject that information into the inputs of
|
|
||||||
future executions of the Chain. For example, for conversational Chains Memory
|
|
||||||
can be used to store conversations and automatically add them to future model
|
|
||||||
prompts so that the model has the necessary context to respond coherently to
|
|
||||||
the latest input.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
class SimpleMemory(BaseMemory):
|
|
||||||
memories: Dict[str, Any] = dict()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def memory_variables(self) -> List[str]:
|
|
||||||
return list(self.memories.keys())
|
|
||||||
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
||||||
return self.memories
|
|
||||||
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
pass
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def memory_variables(self) -> List[str]:
|
|
||||||
"""The string keys this memory class will add to chain inputs."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Return key-value pairs given the text input to the chain."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
|
||||||
"""Save the context of this chain run to memory."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def clear(self) -> None:
|
|
||||||
"""Clear memory contents."""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseChatMessageHistory(ABC):
|
|
||||||
"""Abstract base class for storing chat message history.
|
|
||||||
|
|
||||||
See `ChatMessageHistory` for default implementation.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
class FileChatMessageHistory(BaseChatMessageHistory):
|
|
||||||
storage_path: str
|
|
||||||
session_id: str
|
|
||||||
|
|
||||||
@property
|
|
||||||
def messages(self):
|
|
||||||
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
|
|
||||||
messages = json.loads(f.read())
|
|
||||||
return messages_from_dict(messages)
|
|
||||||
|
|
||||||
def add_message(self, message: BaseMessage) -> None:
|
|
||||||
messages = self.messages.append(_message_to_dict(message))
|
|
||||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
|
||||||
json.dump(f, messages)
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
|
||||||
f.write("[]")
|
|
||||||
"""
|
|
||||||
|
|
||||||
messages: List[BaseMessage]
|
|
||||||
"""A list of Messages stored in-memory."""
|
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
|
||||||
"""Convenience method for adding a human message string to the store.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The string contents of a human message.
|
|
||||||
"""
|
|
||||||
self.add_message(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
"""Convenience method for adding an AI message string to the store.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The string contents of an AI message.
|
|
||||||
"""
|
|
||||||
self.add_message(AIMessage(content=message))
|
|
||||||
|
|
||||||
# TODO: Make this an abstractmethod.
|
|
||||||
def add_message(self, message: BaseMessage) -> None:
|
|
||||||
"""Add a Message object to the store.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: A BaseMessage object to store.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def clear(self) -> None:
|
|
||||||
"""Remove all messages from the store"""
|
|
||||||
|
|
||||||
|
|
||||||
class Document(Serializable):
|
|
||||||
"""Class for storing a piece of text and associated metadata."""
|
|
||||||
|
|
||||||
page_content: str
|
|
||||||
"""String text."""
|
|
||||||
metadata: dict = Field(default_factory=dict)
|
|
||||||
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
|
||||||
documents, etc.).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseRetriever(ABC):
|
|
||||||
"""Abstract base class for a Document retrieval system.
|
|
||||||
|
|
||||||
A retrieval system is defined as something that can take string queries and return
|
|
||||||
the most 'relevant' Documents from some source.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
class TFIDFRetriever(BaseRetriever, BaseModel):
|
|
||||||
vectorizer: Any
|
|
||||||
docs: List[Document]
|
|
||||||
tfidf_array: Any
|
|
||||||
k: int = 4
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
|
|
||||||
# Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
|
|
||||||
query_vec = self.vectorizer.transform([query])
|
|
||||||
# Op -- (n_docs,1) -- Cosine Sim with each doc
|
|
||||||
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
|
|
||||||
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
|
|
||||||
|
|
||||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
_new_arg_supported: bool = False
|
|
||||||
_expects_other_args: bool = False
|
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
|
||||||
super().__init_subclass__(**kwargs)
|
|
||||||
# Version upgrade for old retrievers that implemented the public
|
|
||||||
# methods directly.
|
|
||||||
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
|
|
||||||
warnings.warn(
|
|
||||||
"Retrievers must implement abstract `_get_relevant_documents` method"
|
|
||||||
" instead of `get_relevant_documents`",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
swap = cls.get_relevant_documents
|
|
||||||
cls.get_relevant_documents = ( # type: ignore[assignment]
|
|
||||||
BaseRetriever.get_relevant_documents
|
|
||||||
)
|
|
||||||
cls._get_relevant_documents = swap # type: ignore[assignment]
|
|
||||||
if (
|
|
||||||
hasattr(cls, "aget_relevant_documents")
|
|
||||||
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
|
|
||||||
):
|
|
||||||
warnings.warn(
|
|
||||||
"Retrievers must implement abstract `_aget_relevant_documents` method"
|
|
||||||
" instead of `aget_relevant_documents`",
|
|
||||||
DeprecationWarning,
|
|
||||||
)
|
|
||||||
aswap = cls.aget_relevant_documents
|
|
||||||
cls.aget_relevant_documents = ( # type: ignore[assignment]
|
|
||||||
BaseRetriever.aget_relevant_documents
|
|
||||||
)
|
|
||||||
cls._aget_relevant_documents = aswap # type: ignore[assignment]
|
|
||||||
parameters = signature(cls._get_relevant_documents).parameters
|
|
||||||
cls._new_arg_supported = parameters.get("run_manager") is not None
|
|
||||||
# If a V1 retriever broke the interface and expects additional arguments
|
|
||||||
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _get_relevant_documents(
|
|
||||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Get documents relevant to a query.
|
|
||||||
Args:
|
|
||||||
query: String to find relevant documents for.
|
|
||||||
run_manager: The callbacks handler to use.
|
|
||||||
Returns:
|
|
||||||
List of relevant documents
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def _aget_relevant_documents(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
*,
|
|
||||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Asynchronously get documents relevant to a query.
|
|
||||||
Args:
|
|
||||||
query: string to find relevant documents for
|
|
||||||
run_manager: The callbacks handler to use
|
|
||||||
Returns:
|
|
||||||
List of relevant documents
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_relevant_documents(
|
|
||||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Retrieve documents relevant to a query.
|
|
||||||
Args:
|
|
||||||
query: String to find relevant documents for.
|
|
||||||
callbacks: Callback manager or list of callbacks.
|
|
||||||
Returns:
|
|
||||||
List of relevant documents
|
|
||||||
"""
|
|
||||||
from langchain.callbacks.manager import CallbackManager
|
|
||||||
|
|
||||||
callback_manager = CallbackManager.configure(
|
|
||||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
|
||||||
)
|
|
||||||
run_manager = callback_manager.on_retriever_start(
|
|
||||||
query,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
if self._new_arg_supported:
|
|
||||||
result = self._get_relevant_documents(
|
|
||||||
query, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
elif self._expects_other_args:
|
|
||||||
result = self._get_relevant_documents(query, **kwargs)
|
|
||||||
else:
|
|
||||||
result = self._get_relevant_documents(query) # type: ignore[call-arg]
|
|
||||||
except Exception as e:
|
|
||||||
run_manager.on_retriever_error(e)
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
run_manager.on_retriever_end(
|
|
||||||
result,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def aget_relevant_documents(
|
|
||||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Asynchronously get documents relevant to a query.
|
|
||||||
Args:
|
|
||||||
query: string to find relevant documents for
|
|
||||||
callbacks: Callback manager or list of callbacks
|
|
||||||
Returns:
|
|
||||||
List of relevant documents
|
|
||||||
"""
|
|
||||||
from langchain.callbacks.manager import AsyncCallbackManager
|
|
||||||
|
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
|
||||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
|
||||||
)
|
|
||||||
run_manager = await callback_manager.on_retriever_start(
|
|
||||||
query,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
if self._new_arg_supported:
|
|
||||||
result = await self._aget_relevant_documents(
|
|
||||||
query, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
elif self._expects_other_args:
|
|
||||||
result = await self._aget_relevant_documents(query, **kwargs)
|
|
||||||
else:
|
|
||||||
result = await self._aget_relevant_documents(
|
|
||||||
query, # type: ignore[call-arg]
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
await run_manager.on_retriever_error(e)
|
|
||||||
raise e
|
|
||||||
else:
|
|
||||||
await run_manager.on_retriever_end(
|
|
||||||
result,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return result
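
An illustrative sketch (not in the original changeset) of the retriever contract this class enforces: subclasses implement the private _get_relevant_documents / _aget_relevant_documents hooks with the run_manager keyword, and the public methods above handle the callback plumbing. The KeywordRetriever name and its substring-matching logic are invented for the example.

from typing import Any, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document


class KeywordRetriever(BaseRetriever):
    """Toy retriever: returns every stored Document containing the query string."""

    def __init__(self, docs: List[Document]) -> None:
        self.docs = docs

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        return [d for d in self.docs if query.lower() in d.page_content.lower()]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        # Matching is cheap, so the async hook just reuses the same logic.
        return [d for d in self.docs if query.lower() in d.page_content.lower()]


retriever = KeywordRetriever([Document(page_content="LangChain splits the schema package")])
print(retriever.get_relevant_documents("schema"))  # -> [Document(page_content='...')]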
|
|
||||||
|
|
||||||
|
|
||||||
# For backwards compatibility
|
|
||||||
Memory = BaseMemory
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
|
|
||||||
"""Abstract base class for parsing the outputs of a model."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def parse_result(self, result: List[Generation]) -> T:
|
|
||||||
"""Parse a list of candidate model Generations into a specific format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: A list of Generations to be parsed. The Generations are assumed
|
|
||||||
to be different candidate outputs for a single model input.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Structured output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
|
|
||||||
"""Class to parse the output of an LLM call.
|
|
||||||
|
|
||||||
Output parsers help structure language model responses.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
class BooleanOutputParser(BaseOutputParser[bool]):
|
|
||||||
true_val: str = "YES"
|
|
||||||
false_val: str = "NO"
|
|
||||||
|
|
||||||
def parse(self, text: str) -> bool:
|
|
||||||
cleaned_text = text.strip().upper()
|
|
||||||
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
|
|
||||||
raise OutputParserException(
|
|
||||||
f"BooleanOutputParser expected output value to either be "
|
|
||||||
f"{self.true_val} or {self.false_val} (case-insensitive). "
|
|
||||||
f"Received {cleaned_text}."
|
|
||||||
)
|
|
||||||
return cleaned_text == self.true_val.upper()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _type(self) -> str:
|
|
||||||
return "boolean_output_parser"
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation]) -> T:
|
|
||||||
"""Parse a list of candidate model Generations into a specific format.
|
|
||||||
|
|
||||||
The return value is parsed from only the first Generation in the result, which
|
|
||||||
is assumed to be the highest-likelihood Generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
result: A list of Generations to be parsed. The Generations are assumed
|
|
||||||
to be different candidate outputs for a single model input.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Structured output.
|
|
||||||
"""
|
|
||||||
return self.parse(result[0].text)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def parse(self, text: str) -> T:
|
|
||||||
"""Parse a single string model output into some structure.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: String output of language model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Structured output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO: rename 'completion' -> 'text'.
|
|
||||||
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
|
||||||
"""Parse the output of an LLM call with the input prompt for context.
|
|
||||||
|
|
||||||
The prompt is largely provided in the event the OutputParser wants
|
|
||||||
to retry or fix the output in some way, and needs information from
|
|
||||||
the prompt to do so.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
completion: String output of language model.
|
|
||||||
prompt: Input PromptValue.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Structured output
|
|
||||||
"""
|
|
||||||
return self.parse(completion)
|
|
||||||
|
|
||||||
def get_format_instructions(self) -> str:
|
|
||||||
"""Instructions on how the LLM output should be formatted."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _type(self) -> str:
|
|
||||||
"""Return the output parser type for serialization."""
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"_type property is not implemented in class {self.__class__.__name__}."
|
|
||||||
" This is required for serialization."
|
|
||||||
)
|
|
||||||
|
|
||||||
def dict(self, **kwargs: Any) -> Dict:
|
|
||||||
"""Return dictionary representation of output parser."""
|
|
||||||
output_parser_dict = super().dict(**kwargs)
|
|
||||||
output_parser_dict["_type"] = self._type
|
|
||||||
return output_parser_dict
|
|
||||||
|
|
||||||
|
|
||||||
class NoOpOutputParser(BaseOutputParser[str]):
|
|
||||||
"""'No operation' OutputParser that returns the text as is."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_serializable(self) -> bool:
|
|
||||||
"""Whether the class LangChain serializable."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _type(self) -> str:
|
|
||||||
"""Return the output parser type for serialization."""
|
|
||||||
return "default"
|
|
||||||
|
|
||||||
def parse(self, text: str) -> str:
|
|
||||||
"""Returns the input text with no changes."""
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
class OutputParserException(ValueError):
|
|
||||||
"""Exception that output parsers should raise to signify a parsing error.
|
|
||||||
|
|
||||||
This exists to differentiate parsing errors from other code or execution errors
|
|
||||||
that also may arise inside the output parser. OutputParserExceptions will be
|
|
||||||
available to catch and handle in ways to fix the parsing error, while other
|
|
||||||
errors will be raised.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
error: The error that's being re-raised or an error message.
|
|
||||||
observation: String explanation of error which can be passed to a
|
|
||||||
model to try and remediate the issue.
|
|
||||||
llm_output: The string model output that caused the error.
|
|
||||||
send_to_llm: Whether to send the observation and llm_output back to an Agent
|
|
||||||
after an OutputParserException has been raised. This gives the underlying
|
|
||||||
model driving the agent the context that the previous output was improperly
|
|
||||||
structured, in the hopes that it will update the output to the correct
|
|
||||||
format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
error: Any,
|
|
||||||
observation: Optional[str] = None,
|
|
||||||
llm_output: Optional[str] = None,
|
|
||||||
send_to_llm: bool = False,
|
|
||||||
):
|
|
||||||
super(OutputParserException, self).__init__(error)
|
|
||||||
if send_to_llm:
|
|
||||||
if observation is None or llm_output is None:
|
|
||||||
raise ValueError(
|
|
||||||
"Arguments 'observation' & 'llm_output'"
|
|
||||||
" are required if 'send_to_llm' is True"
|
|
||||||
)
|
|
||||||
self.observation = observation
|
|
||||||
self.llm_output = llm_output
|
|
||||||
self.send_to_llm = send_to_llm
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDocumentTransformer(ABC):
|
|
||||||
"""Abstract base class for document transformation systems.
|
|
||||||
|
|
||||||
A document transformation system takes a sequence of Documents and returns a
|
|
||||||
sequence of transformed Documents.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
|
||||||
embeddings: Embeddings
|
|
||||||
similarity_fn: Callable = cosine_similarity
|
|
||||||
similarity_threshold: float = 0.95
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
def transform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
stateful_documents = get_stateful_documents(documents)
|
|
||||||
embedded_documents = _get_embeddings_from_stateful_docs(
|
|
||||||
self.embeddings, stateful_documents
|
|
||||||
)
|
|
||||||
included_idxs = _filter_similar_embeddings(
|
|
||||||
embedded_documents, self.similarity_fn, self.similarity_threshold
|
|
||||||
)
|
|
||||||
return [stateful_documents[i] for i in sorted(included_idxs)]
|
|
||||||
|
|
||||||
async def atransform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def transform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
"""Transform a list of documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: A sequence of Documents to be transformed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of transformed Documents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def atransform_documents(
|
|
||||||
self, documents: Sequence[Document], **kwargs: Any
|
|
||||||
) -> Sequence[Document]:
|
|
||||||
"""Asynchronously transform a list of documents.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: A sequence of Documents to be transformed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of transformed Documents.
|
|
||||||
"""
|
|
@@ -0,0 +1,67 @@
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
from langchain.schema.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    _message_from_dict,
    _message_to_dict,
    get_buffer_string,
    messages_from_dict,
    messages_to_dict,
)
from langchain.schema.output import (
    ChatGeneration,
    ChatResult,
    Generation,
    LLMResult,
    RunInfo,
)
from langchain.schema.output_parser import (
    BaseLLMOutputParser,
    BaseOutputParser,
    NoOpOutputParser,
    OutputParserException,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.retriever import BaseRetriever

RUN_KEY = "__run"
Memory = BaseMemory

__all__ = [
    "BaseMemory",
    "BaseChatMessageHistory",
    "AgentFinish",
    "AgentAction",
    "Document",
    "BaseDocumentTransformer",
    "BaseMessage",
    "ChatMessage",
    "FunctionMessage",
    "HumanMessage",
    "AIMessage",
    "SystemMessage",
    "messages_from_dict",
    "messages_to_dict",
    "_message_to_dict",
    "_message_from_dict",
    "get_buffer_string",
    "RunInfo",
    "LLMResult",
    "ChatResult",
    "ChatGeneration",
    "Generation",
    "PromptValue",
    "BaseRetriever",
    "RUN_KEY",
    "Memory",
    "OutputParserException",
    "NoOpOutputParser",
    "BaseOutputParser",
    "BaseLLMOutputParser",
]
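
A quick sanity check (illustrative, not in the original changeset) that the re-exports above keep the old flat import path working: names imported from langchain.schema and from the new submodules resolve to the same objects.

from langchain.schema import Document, HumanMessage, OutputParserException
from langchain.schema.document import Document as DocumentFromSubmodule
from langchain.schema.messages import HumanMessage as HumanMessageFromSubmodule
from langchain.schema.output_parser import (
    OutputParserException as OutputParserExceptionFromSubmodule,
)

# The __init__ re-exports are the same objects, so existing imports keep working.
assert Document is DocumentFromSubmodule
assert HumanMessage is HumanMessageFromSubmodule
assert OutputParserException is OutputParserExceptionFromSubmodule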
@@ -0,0 +1,25 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import NamedTuple, Union


@dataclass
class AgentAction:
    """A full description of an action for an ActionAgent to execute."""

    tool: str
    """The name of the Tool to execute."""
    tool_input: Union[str, dict]
    """The input to pass in to the Tool."""
    log: str
    """Additional information to log about the action."""


class AgentFinish(NamedTuple):
    """The final return value of an ActionAgent."""

    return_values: dict
    """Dictionary of return values."""
    log: str
    """Additional information to log about the return value"""
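
Illustrative usage (not in the original changeset) of the two shapes defined above: AgentAction is a dataclass describing one tool call, AgentFinish a NamedTuple carrying the final return values. The tool name and inputs are invented for the example.

from langchain.schema.agent import AgentAction, AgentFinish

# One agent step: which tool to call, with what input, plus free-form log text.
action = AgentAction(tool="search", tool_input={"query": "weather in SF"}, log="calling search")

# The terminal step: the values handed back to the caller.
finish = AgentFinish(return_values={"output": "It is sunny."}, log="done")

print(action.tool, action.tool_input)
print(finish.return_values["output"])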
@@ -0,0 +1,82 @@
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
|
||||||
|
|
||||||
|
class Document(Serializable):
|
||||||
|
"""Class for storing a piece of text and associated metadata."""
|
||||||
|
|
||||||
|
page_content: str
|
||||||
|
"""String text."""
|
||||||
|
metadata: dict = Field(default_factory=dict)
|
||||||
|
"""Arbitrary metadata about the page content (e.g., source, relationships to other
|
||||||
|
documents, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDocumentTransformer(ABC):
|
||||||
|
"""Abstract base class for document transformation systems.
|
||||||
|
|
||||||
|
A document transformation system takes a sequence of Documents and returns a
|
||||||
|
sequence of transformed Documents.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
|
||||||
|
embeddings: Embeddings
|
||||||
|
similarity_fn: Callable = cosine_similarity
|
||||||
|
similarity_threshold: float = 0.95
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
def transform_documents(
|
||||||
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
stateful_documents = get_stateful_documents(documents)
|
||||||
|
embedded_documents = _get_embeddings_from_stateful_docs(
|
||||||
|
self.embeddings, stateful_documents
|
||||||
|
)
|
||||||
|
included_idxs = _filter_similar_embeddings(
|
||||||
|
embedded_documents, self.similarity_fn, self.similarity_threshold
|
||||||
|
)
|
||||||
|
return [stateful_documents[i] for i in sorted(included_idxs)]
|
||||||
|
|
||||||
|
async def atransform_documents(
|
||||||
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def transform_documents(
|
||||||
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
"""Transform a list of documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: A sequence of Documents to be transformed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of transformed Documents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def atransform_documents(
|
||||||
|
self, documents: Sequence[Document], **kwargs: Any
|
||||||
|
) -> Sequence[Document]:
|
||||||
|
"""Asynchronously transform a list of documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
documents: A sequence of Documents to be transformed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of transformed Documents.
|
||||||
|
"""
@@ -0,0 +1,121 @@
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMemory(Serializable, ABC):
|
||||||
|
"""Base abstract class for memory in Chains.
|
||||||
|
|
||||||
|
Memory refers to state in Chains. Memory can be used to store information about
|
||||||
|
past executions of a Chain and inject that information into the inputs of
|
||||||
|
future executions of the Chain. For example, for conversational Chains Memory
|
||||||
|
can be used to store conversations and automatically add them to future model
|
||||||
|
prompts so that the model has the necessary context to respond coherently to
|
||||||
|
the latest input.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class SimpleMemory(BaseMemory):
|
||||||
|
memories: Dict[str, Any] = dict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def memory_variables(self) -> List[str]:
|
||||||
|
return list(self.memories.keys())
|
||||||
|
|
||||||
|
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
return self.memories
|
||||||
|
|
||||||
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
pass
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def memory_variables(self) -> List[str]:
|
||||||
|
"""The string keys this memory class will add to chain inputs."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
|
||||||
|
"""Save the context of this chain run to memory."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear memory contents."""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChatMessageHistory(ABC):
|
||||||
|
"""Abstract base class for storing chat message history.
|
||||||
|
|
||||||
|
See `ChatMessageHistory` for default implementation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class FileChatMessageHistory(BaseChatMessageHistory):
|
||||||
|
storage_path: str
|
||||||
|
session_id: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messages(self):
|
||||||
|
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
|
||||||
|
messages = json.loads(f.read())
|
||||||
|
return messages_from_dict(messages)
|
||||||
|
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
messages = self.messages.append(_message_to_dict(message))
|
||||||
|
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||||
|
json.dump(f, messages)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||||
|
f.write("[]")
|
||||||
|
"""
|
||||||
|
|
||||||
|
messages: List[BaseMessage]
|
||||||
|
"""A list of Messages stored in-memory."""
|
||||||
|
|
||||||
|
def add_user_message(self, message: str) -> None:
|
||||||
|
"""Convenience method for adding a human message string to the store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The string contents of a human message.
|
||||||
|
"""
|
||||||
|
self.add_message(HumanMessage(content=message))
|
||||||
|
|
||||||
|
def add_ai_message(self, message: str) -> None:
|
||||||
|
"""Convenience method for adding an AI message string to the store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The string contents of an AI message.
|
||||||
|
"""
|
||||||
|
self.add_message(AIMessage(content=message))
|
||||||
|
|
||||||
|
# TODO: Make this an abstractmethod.
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Add a Message object to the store.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: A BaseMessage object to store.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove all messages from the store"""
@@ -0,0 +1,183 @@
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import List, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
|
||||||
|
|
||||||
|
def get_buffer_string(
|
||||||
|
messages: Sequence[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||||
|
) -> str:
|
||||||
|
"""Convert sequence of Messages to strings and concatenate them into one string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Messages to be converted to strings.
|
||||||
|
human_prefix: The prefix to prepend to contents of HumanMessages.
|
||||||
|
ai_prefix: The prefix to prepend to contents of AIMessages.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A single string concatenation of all input messages.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.schema import AIMessage, HumanMessage
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
HumanMessage(content="Hi, how are you?"),
|
||||||
|
AIMessage(content="Good, how are you?"),
|
||||||
|
]
|
||||||
|
get_buffer_string(messages)
|
||||||
|
# -> "Human: Hi, how are you?\nAI: Good, how are you?"
|
||||||
|
"""
|
||||||
|
string_messages = []
|
||||||
|
for m in messages:
|
||||||
|
if isinstance(m, HumanMessage):
|
||||||
|
role = human_prefix
|
||||||
|
elif isinstance(m, AIMessage):
|
||||||
|
role = ai_prefix
|
||||||
|
elif isinstance(m, SystemMessage):
|
||||||
|
role = "System"
|
||||||
|
elif isinstance(m, FunctionMessage):
|
||||||
|
role = "Function"
|
||||||
|
elif isinstance(m, ChatMessage):
|
||||||
|
role = m.role
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unsupported message type: {m}")
|
||||||
|
message = f"{role}: {m.content}"
|
||||||
|
if isinstance(m, AIMessage) and "function_call" in m.additional_kwargs:
|
||||||
|
message += f"{m.additional_kwargs['function_call']}"
|
||||||
|
string_messages.append(message)
|
||||||
|
|
||||||
|
return "\n".join(string_messages)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMessage(Serializable):
|
||||||
|
"""The base abstract Message class.
|
||||||
|
|
||||||
|
Messages are the inputs and outputs of ChatModels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
"""The string contents of the message."""
|
||||||
|
|
||||||
|
additional_kwargs: dict = Field(default_factory=dict)
|
||||||
|
"""Any additional information."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the Message, used for serialization."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
"""Whether this class is LangChain serializable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class HumanMessage(BaseMessage):
|
||||||
|
"""A Message from a human."""
|
||||||
|
|
||||||
|
example: bool = False
|
||||||
|
"""Whether this Message is being passed in to the model as part of an example
|
||||||
|
conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "human"
|
||||||
|
|
||||||
|
|
||||||
|
class AIMessage(BaseMessage):
|
||||||
|
"""A Message from an AI."""
|
||||||
|
|
||||||
|
example: bool = False
|
||||||
|
"""Whether this Message is being passed in to the model as part of an example
|
||||||
|
conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "ai"
|
||||||
|
|
||||||
|
|
||||||
|
class SystemMessage(BaseMessage):
|
||||||
|
"""A Message for priming AI behavior, usually passed in as the first of a sequence
|
||||||
|
of input messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "system"
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionMessage(BaseMessage):
|
||||||
|
"""A Message for passing the result of executing a function back to a model."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""The name of the function that was executed."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "function"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseMessage):
|
||||||
|
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
|
||||||
|
|
||||||
|
role: str
|
||||||
|
"""The speaker / role of the Message."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "chat"
|
||||||
|
|
||||||
|
|
||||||
|
def _message_to_dict(message: BaseMessage) -> dict:
|
||||||
|
return {"type": message.type, "data": message.dict()}
|
||||||
|
|
||||||
|
|
||||||
|
def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
|
||||||
|
"""Convert a sequence of Messages to a list of dictionaries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Sequence of messages (as BaseMessages) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of messages as dicts.
|
||||||
|
"""
|
||||||
|
return [_message_to_dict(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
|
def _message_from_dict(message: dict) -> BaseMessage:
|
||||||
|
_type = message["type"]
|
||||||
|
if _type == "human":
|
||||||
|
return HumanMessage(**message["data"])
|
||||||
|
elif _type == "ai":
|
||||||
|
return AIMessage(**message["data"])
|
||||||
|
elif _type == "system":
|
||||||
|
return SystemMessage(**message["data"])
|
||||||
|
elif _type == "chat":
|
||||||
|
return ChatMessage(**message["data"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unexpected type: {_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||||
|
"""Convert a sequence of messages from dicts to Message objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Sequence of messages (as dicts) to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of messages (BaseMessages).
|
||||||
|
"""
|
||||||
|
return [_message_from_dict(m) for m in messages]
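
An illustrative round trip (not in the original changeset) through the serialization helpers above: messages_to_dict produces plain dicts keyed by message type, and messages_from_dict rebuilds the corresponding BaseMessage subclasses.

from langchain.schema.messages import (
    AIMessage,
    HumanMessage,
    messages_from_dict,
    messages_to_dict,
)

original = [
    HumanMessage(content="Hi, how are you?"),
    AIMessage(content="Good, how are you?"),
]

as_dicts = messages_to_dict(original)   # e.g. [{"type": "human", "data": {...}}, ...]
restored = messages_from_dict(as_dicts)

assert isinstance(restored[0], HumanMessage) and isinstance(restored[1], AIMessage)
assert messages_to_dict(restored) == as_dicts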
@@ -0,0 +1,118 @@
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.messages import BaseMessage
|
||||||
|
|
||||||
|
|
||||||
|
class Generation(Serializable):
|
||||||
|
"""A single text generation output."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
"""Generated text output."""
|
||||||
|
|
||||||
|
generation_info: Optional[Dict[str, Any]] = None
|
||||||
|
"""Raw response from the provider. May include things like the
|
||||||
|
reason for finishing or token log probabilities.
|
||||||
|
"""
|
||||||
|
# TODO: add log probs as separate attribute
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
"""Whether this class is LangChain serializable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGeneration(Generation):
|
||||||
|
"""A single chat generation output."""
|
||||||
|
|
||||||
|
text: str = ""
|
||||||
|
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
||||||
|
message: BaseMessage
|
||||||
|
"""The message output by the chat model."""
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Set the text attribute to be the contents of the message."""
|
||||||
|
values["text"] = values["message"].content
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class RunInfo(BaseModel):
|
||||||
|
"""Class that contains metadata for a single execution of a Chain or model."""
|
||||||
|
|
||||||
|
run_id: UUID
|
||||||
|
"""A unique identifier for the model or chain run."""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResult(BaseModel):
|
||||||
|
"""Class that contains all results for a single chat model call."""
|
||||||
|
|
||||||
|
generations: List[ChatGeneration]
|
||||||
|
"""List of the chat generations. This is a List because an input can have multiple
|
||||||
|
candidate generations.
|
||||||
|
"""
|
||||||
|
llm_output: Optional[dict] = None
|
||||||
|
"""For arbitrary LLM provider specific output."""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMResult(BaseModel):
|
||||||
|
"""Class that contains all results for a batched LLM call."""
|
||||||
|
|
||||||
|
generations: List[List[Generation]]
|
||||||
|
"""List of generated outputs. This is a List[List[]] because
|
||||||
|
each input could have multiple candidate generations."""
|
||||||
|
llm_output: Optional[dict] = None
|
||||||
|
"""Arbitrary LLM provider-specific output."""
|
||||||
|
run: Optional[List[RunInfo]] = None
|
||||||
|
"""List of metadata info for model call for each input."""
|
||||||
|
|
||||||
|
def flatten(self) -> List[LLMResult]:
|
||||||
|
"""Flatten generations into a single list.
|
||||||
|
|
||||||
|
Unpack List[List[Generation]] -> List[LLMResult] where each returned LLMResult
|
||||||
|
contains only a single Generation. If token usage information is available,
|
||||||
|
it is kept only for the LLMResult corresponding to the top-choice
|
||||||
|
Generation, to avoid over-counting of token usage downstream.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of LLMResults where each returned LLMResult contains a single
|
||||||
|
Generation.
|
||||||
|
"""
|
||||||
|
llm_results = []
|
||||||
|
for i, gen_list in enumerate(self.generations):
|
||||||
|
# Avoid double counting tokens in OpenAICallback
|
||||||
|
if i == 0:
|
||||||
|
llm_results.append(
|
||||||
|
LLMResult(
|
||||||
|
generations=[gen_list],
|
||||||
|
llm_output=self.llm_output,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.llm_output is not None:
|
||||||
|
llm_output = deepcopy(self.llm_output)
|
||||||
|
llm_output["token_usage"] = dict()
|
||||||
|
else:
|
||||||
|
llm_output = None
|
||||||
|
llm_results.append(
|
||||||
|
LLMResult(
|
||||||
|
generations=[gen_list],
|
||||||
|
llm_output=llm_output,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return llm_results
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
"""Check for LLMResult equality by ignoring any metadata related to runs."""
|
||||||
|
if not isinstance(other, LLMResult):
|
||||||
|
return NotImplemented
|
||||||
|
return (
|
||||||
|
self.generations == other.generations
|
||||||
|
and self.llm_output == other.llm_output
|
||||||
|
)
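
An illustrative example (not in the original changeset) of LLMResult.flatten as documented above: two prompts, the first with two candidate generations, unpack into one LLMResult per prompt, and token usage is kept only on the first result to avoid double counting. The token counts are made up.

from langchain.schema.output import Generation, LLMResult

result = LLMResult(
    generations=[
        [Generation(text="first candidate"), Generation(text="second candidate")],
        [Generation(text="answer for the second prompt")],
    ],
    llm_output={"token_usage": {"total_tokens": 42}},
)

flat = result.flatten()
print(len(flat))              # -> 2, one LLMResult per input prompt
print(flat[0].llm_output)     # -> {'token_usage': {'total_tokens': 42}}
print(flat[1].llm_output)     # -> {'token_usage': {}}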
@@ -0,0 +1,172 @@
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict, Generic, List, Optional, TypeVar
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.schema.output import Generation
|
||||||
|
from langchain.schema.prompt import PromptValue
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
|
||||||
|
"""Abstract base class for parsing the outputs of a model."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_result(self, result: List[Generation]) -> T:
|
||||||
|
"""Parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: A list of Generations to be parsed. The Generations are assumed
|
||||||
|
to be different candidate outputs for a single model input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Structured output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
|
||||||
|
"""Class to parse the output of an LLM call.
|
||||||
|
|
||||||
|
Output parsers help structure language model responses.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class BooleanOutputParser(BaseOutputParser[bool]):
|
||||||
|
true_val: str = "YES"
|
||||||
|
false_val: str = "NO"
|
||||||
|
|
||||||
|
def parse(self, text: str) -> bool:
|
||||||
|
cleaned_text = text.strip().upper()
|
||||||
|
if cleaned_text not in (self.true_val.upper(), self.false_val.upper()):
|
||||||
|
raise OutputParserException(
|
||||||
|
f"BooleanOutputParser expected output value to either be "
|
||||||
|
f"{self.true_val} or {self.false_val} (case-insensitive). "
|
||||||
|
f"Received {cleaned_text}."
|
||||||
|
)
|
||||||
|
return cleaned_text == self.true_val.upper()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "boolean_output_parser"
|
||||||
|
""" # noqa: E501
|
||||||
|
|
||||||
|
def parse_result(self, result: List[Generation]) -> T:
|
||||||
|
"""Parse a list of candidate model Generations into a specific format.
|
||||||
|
|
||||||
|
The return value is parsed from only the first Generation in the result, which
|
||||||
|
is assumed to be the highest-likelihood Generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: A list of Generations to be parsed. The Generations are assumed
|
||||||
|
to be different candidate outputs for a single model input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Structured output.
|
||||||
|
"""
|
||||||
|
return self.parse(result[0].text)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse(self, text: str) -> T:
|
||||||
|
"""Parse a single string model output into some structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: String output of language model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Structured output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: rename 'completion' -> 'text'.
|
||||||
|
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
|
||||||
|
"""Parse the output of an LLM call with the input prompt for context.
|
||||||
|
|
||||||
|
The prompt is largely provided in the event the OutputParser wants
|
||||||
|
to retry or fix the output in some way, and needs information from
|
||||||
|
the prompt to do so.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
completion: String output of language model.
|
||||||
|
prompt: Input PromptValue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Structured output
|
||||||
|
"""
|
||||||
|
return self.parse(completion)
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
"""Instructions on how the LLM output should be formatted."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
"""Return the output parser type for serialization."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"_type property is not implemented in class {self.__class__.__name__}."
|
||||||
|
" This is required for serialization."
|
||||||
|
)
|
||||||
|
|
||||||
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
|
"""Return dictionary representation of output parser."""
|
||||||
|
output_parser_dict = super().dict(**kwargs)
|
||||||
|
output_parser_dict["_type"] = self._type
|
||||||
|
return output_parser_dict
|
||||||
|
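A hedged usage sketch, assuming the BooleanOutputParser subclass from the docstring example above has been defined; it shows that parse_result() reads only generations[0].

.. code-block:: python

    from langchain.schema import Generation

    parser = BooleanOutputParser()

    parser.parse(" yes ")  # -> True (input is stripped and upper-cased)
    # Only the first candidate Generation is parsed:
    parser.parse_result([Generation(text="NO"), Generation(text="YES")])  # -> False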

class NoOpOutputParser(BaseOutputParser[str]):
    """'No operation' OutputParser that returns the text as is."""

    @property
    def lc_serializable(self) -> bool:
        """Whether the class is LangChain serializable."""
        return True

    @property
    def _type(self) -> str:
        """Return the output parser type for serialization."""
        return "default"

    def parse(self, text: str) -> str:
        """Returns the input text with no changes."""
        return text


class OutputParserException(ValueError):
    """Exception that output parsers should raise to signify a parsing error.

    This exists to differentiate parsing errors from other code or execution errors
    that also may arise inside the output parser. OutputParserExceptions will be
    available to catch and handle in ways to fix the parsing error, while other
    errors will be raised.

    Args:
        error: The error that's being re-raised or an error message.
        observation: String explanation of the error which can be passed to a
            model to try and remediate the issue.
        llm_output: String model output which is erroring.
        send_to_llm: Whether to send the observation and llm_output back to an Agent
            after an OutputParserException has been raised. This gives the underlying
            model driving the agent the context that the previous output was improperly
            structured, in the hopes that it will update the output to the correct
            format.
    """

    def __init__(
        self,
        error: Any,
        observation: Optional[str] = None,
        llm_output: Optional[str] = None,
        send_to_llm: bool = False,
    ):
        super(OutputParserException, self).__init__(error)
        if send_to_llm:
            if observation is None or llm_output is None:
                raise ValueError(
                    "Arguments 'observation' & 'llm_output'"
                    " are required if 'send_to_llm' is True"
                )
        self.observation = observation
        self.llm_output = llm_output
        self.send_to_llm = send_to_llm
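The send_to_llm contract is easiest to see in use. A minimal sketch (not from this diff; parse_strict is a hypothetical parser function) of raising and catching the exception so an agent can feed the context back to the model:

.. code-block:: python

    from langchain.schema import OutputParserException

    def parse_strict(text: str) -> int:
        if not text.strip().isdigit():
            raise OutputParserException(
                error=f"Expected an integer, got {text!r}",
                observation="Reply with digits only.",
                llm_output=text,
                send_to_llm=True,
            )
        return int(text.strip())

    try:
        parse_strict("forty-two")
    except OutputParserException as e:
        # With send_to_llm=True, observation and llm_output are guaranteed
        # non-None by __init__, so both can be handed back to the model.
        retry_context = f"{e.llm_output}\n{e.observation}"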
langchain/schema/prompt.py
@@ -0,0 +1,23 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage


class PromptValue(Serializable, ABC):
    """Base abstract class for inputs to any language model.

    PromptValues can be converted to both LLM (pure text-generation) inputs and
    ChatModel inputs.
    """

    @abstractmethod
    def to_string(self) -> str:
        """Return prompt value as string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as a list of Messages."""
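A minimal concrete subclass sketch (illustrative, not part of the diff; TextPromptValue is a hypothetical name) showing the two conversions the abstract class demands:

.. code-block:: python

    from typing import List

    from langchain.schema import HumanMessage
    from langchain.schema.messages import BaseMessage
    from langchain.schema.prompt import PromptValue

    class TextPromptValue(PromptValue):
        """Hypothetical PromptValue wrapping a plain string."""

        text: str

        def to_string(self) -> str:
            return self.text

        def to_messages(self) -> List[BaseMessage]:
            # Chat models see the same content as a single human turn.
            return [HumanMessage(content=self.text)]

    value = TextPromptValue(text="What is the capital of France?")
    value.to_string()    # LLM input: "What is the capital of France?"
    value.to_messages()  # ChatModel input: [HumanMessage(...)]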
langchain/schema/retriever.py
@@ -0,0 +1,191 @@
from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, List

from langchain.schema.document import Document

if TYPE_CHECKING:
    from langchain.callbacks.manager import (
        AsyncCallbackManagerForRetrieverRun,
        CallbackManagerForRetrieverRun,
        Callbacks,
    )


class BaseRetriever(ABC):
    """Abstract base class for a Document retrieval system.

    A retrieval system is defined as something that can take string queries and return
    the most 'relevant' Documents from some source.

    Example:
        .. code-block:: python

            class TFIDFRetriever(BaseRetriever, BaseModel):
                vectorizer: Any
                docs: List[Document]
                tfidf_array: Any
                k: int = 4

                class Config:
                    arbitrary_types_allowed = True

                def get_relevant_documents(self, query: str) -> List[Document]:
                    from sklearn.metrics.pairwise import cosine_similarity

                    # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
                    query_vec = self.vectorizer.transform([query])
                    # Op -- (n_docs,1) -- Cosine Sim with each doc
                    results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
                    return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]

                async def aget_relevant_documents(self, query: str) -> List[Document]:
                    raise NotImplementedError

    """  # noqa: E501

    _new_arg_supported: bool = False
    _expects_other_args: bool = False

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        # Version upgrade for old retrievers that implemented the public
        # methods directly.
        if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
            warnings.warn(
                "Retrievers must implement abstract `_get_relevant_documents` method"
                " instead of `get_relevant_documents`",
                DeprecationWarning,
            )
            swap = cls.get_relevant_documents
            cls.get_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.get_relevant_documents
            )
            cls._get_relevant_documents = swap  # type: ignore[assignment]
        if (
            hasattr(cls, "aget_relevant_documents")
            and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
        ):
            warnings.warn(
                "Retrievers must implement abstract `_aget_relevant_documents` method"
                " instead of `aget_relevant_documents`",
                DeprecationWarning,
            )
            aswap = cls.aget_relevant_documents
            cls.aget_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.aget_relevant_documents
            )
            cls._aget_relevant_documents = aswap  # type: ignore[assignment]
        parameters = signature(cls._get_relevant_documents).parameters
        cls._new_arg_supported = parameters.get("run_manager") is not None
        # If a V1 retriever broke the interface and expects additional arguments
        cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2

    @abstractmethod
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use.

        Returns:
            List of relevant documents.
        """

    @abstractmethod
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            run_manager: The callbacks handler to use.

        Returns:
            List of relevant documents.
        """

    def get_relevant_documents(
        self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[Document]:
        """Retrieve documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            callbacks: Callback manager or list of callbacks.

        Returns:
            List of relevant documents.
        """
        from langchain.callbacks.manager import CallbackManager

        callback_manager = CallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = callback_manager.on_retriever_start(
            query,
            **kwargs,
        )
        try:
            if self._new_arg_supported:
                result = self._get_relevant_documents(
                    query, run_manager=run_manager, **kwargs
                )
            elif self._expects_other_args:
                result = self._get_relevant_documents(query, **kwargs)
            else:
                result = self._get_relevant_documents(query)  # type: ignore[call-arg]
        except Exception as e:
            run_manager.on_retriever_error(e)
            raise e
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

    async def aget_relevant_documents(
        self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for.
            callbacks: Callback manager or list of callbacks.

        Returns:
            List of relevant documents.
        """
        from langchain.callbacks.manager import AsyncCallbackManager

        callback_manager = AsyncCallbackManager.configure(
            callbacks, None, verbose=kwargs.get("verbose", False)
        )
        run_manager = await callback_manager.on_retriever_start(
            query,
            **kwargs,
        )
        try:
            if self._new_arg_supported:
                result = await self._aget_relevant_documents(
                    query, run_manager=run_manager, **kwargs
                )
            elif self._expects_other_args:
                result = await self._aget_relevant_documents(query, **kwargs)
            else:
                result = await self._aget_relevant_documents(
                    query,  # type: ignore[call-arg]
                )
        except Exception as e:
            await run_manager.on_retriever_error(e)
            raise e
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result
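To close out the new interface, a hedged sketch (not part of this diff) of a subclass implementing the private `_get_relevant_documents` hook; `KeywordRetriever` and its containment-match rule are illustrative only. Legacy subclasses that still override the public methods are rewired by `__init_subclass__` above and emit a DeprecationWarning instead.

.. code-block:: python

    from typing import Any, List

    from langchain.callbacks.manager import (
        AsyncCallbackManagerForRetrieverRun,
        CallbackManagerForRetrieverRun,
    )
    from langchain.schema import BaseRetriever, Document

    class KeywordRetriever(BaseRetriever):
        """Illustrative retriever: returns documents containing the query string."""

        def __init__(self, docs: List[Document]) -> None:
            self.docs = docs

        def _get_relevant_documents(
            self,
            query: str,
            *,
            run_manager: CallbackManagerForRetrieverRun,
            **kwargs: Any,
        ) -> List[Document]:
            # Naive containment match; a real retriever would rank by relevance.
            return [d for d in self.docs if query.lower() in d.page_content.lower()]

        async def _aget_relevant_documents(
            self,
            query: str,
            *,
            run_manager: AsyncCallbackManagerForRetrieverRun,
            **kwargs: Any,
        ) -> List[Document]:
            raise NotImplementedError

    retriever = KeywordRetriever([Document(page_content="Paris is in France.")])
    # The public method wires up callbacks, then dispatches to the private hook
    # because `_new_arg_supported` was detected from its signature.
    docs = retriever.get_relevant_documents("paris")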