langchain/libs/core/langchain_core/retrievers.py

390 lines
14 KiB
Python
Raw Normal View History

"""**Retriever** class returns Documents given a text **query**.
It is more general than a vector store. A retriever does not need to be able to
store documents, only to return (or retrieve) them. Vector stores can be used as
the backbone of a retriever, but there are other types of retrievers as well.
**Class hierarchy:**
.. code-block::
BaseRetriever --> <name>Retriever # Examples: ArxivRetriever, MergerRetriever
**Main helpers:**
.. code-block::
RetrieverInput, RetrieverOutput, RetrieverLike, RetrieverOutputLike,
Document, Serializable, Callbacks,
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
"""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableSerializable,
ensure_config,
)
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Callbacks,
)
Add create_conv_retrieval_chain func (#15084) ``` +----------+ | MapInput | **+----------+**** **** **** **** *** ** **** +------------------------------------+ ** | Lambda(itemgetter('chat_history')) | * +------------------------------------+ * * * * * * * +---------------------------+ +--------------------------------+ | Lambda(_get_chat_history) | | Lambda(itemgetter('question')) | +---------------------------+ +--------------------------------+ * * * * * * +----------------------------+ +------------------------+ | ContextSet('chat_history') | | ContextSet('question') | +----------------------------+ +------------------------+ **** **** **** **** ** ** +-----------+ | MapOutput | +-----------+ * * * +----------------+ | PromptTemplate | +----------------+ * * * +-------------+ | FakeListLLM | +-------------+ * * * +-----------------+ | StrOutputParser | +-----------------+ * * * +----------------------------+ | ContextSet('new_question') | +----------------------------+ * * * +---------------------+ | SequentialRetriever | +---------------------+ * * * +------------------------------------+ | Lambda(_reduce_tokens_below_limit) | +------------------------------------+ * * * +-------------------------------+ | ContextSet('input_documents') | +-------------------------------+ * * * +----------+ ***| MapInput |**** ******* +----------+ ******** ******** * ******* ******* * ******** **** * **** +-------------------------------+ +----------------------------+ +----------------------------+ | ContextGet('input_documents') | | ContextGet('chat_history') | | ContextGet('new_question') | +-------------------------------+**** +----------------------------+ +----------------------------+ ********* * ******* ******** * ****** ***** * **** +-----------+ | MapOutput | +-----------+ * * * +-------------+ | FakeListLLM | +-------------+ * * * +----------+ ***| MapInput |*** ******** +----------+ ****** ******* * ***** ******** * ****** **** * *** 
+-------------------------------+ +----------------------------+ +-------------+ | ContextGet('input_documents') | | ContextGet('new_question') | **| Passthrough | +-------------------------------+ +----------------------------+ ******* +-------------+ ******* * ****** ****** * ******* **** * **** +-----------+ | MapOutput | +-----------+ ``` --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2023-12-27 01:28:10 +00:00
# A retriever's input is the raw query string.
RetrieverInput = str
# A retriever's output is a list of Documents.
RetrieverOutput = List[Document]
# Any Runnable that maps a query string to a list of Documents.
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
# Any Runnable that produces a list of Documents, whatever its input type.
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
Add create_conv_retrieval_chain func (#15084) ``` +----------+ | MapInput | **+----------+**** **** **** **** *** ** **** +------------------------------------+ ** | Lambda(itemgetter('chat_history')) | * +------------------------------------+ * * * * * * * +---------------------------+ +--------------------------------+ | Lambda(_get_chat_history) | | Lambda(itemgetter('question')) | +---------------------------+ +--------------------------------+ * * * * * * +----------------------------+ +------------------------+ | ContextSet('chat_history') | | ContextSet('question') | +----------------------------+ +------------------------+ **** **** **** **** ** ** +-----------+ | MapOutput | +-----------+ * * * +----------------+ | PromptTemplate | +----------------+ * * * +-------------+ | FakeListLLM | +-------------+ * * * +-----------------+ | StrOutputParser | +-----------------+ * * * +----------------------------+ | ContextSet('new_question') | +----------------------------+ * * * +---------------------+ | SequentialRetriever | +---------------------+ * * * +------------------------------------+ | Lambda(_reduce_tokens_below_limit) | +------------------------------------+ * * * +-------------------------------+ | ContextSet('input_documents') | +-------------------------------+ * * * +----------+ ***| MapInput |**** ******* +----------+ ******** ******** * ******* ******* * ******** **** * **** +-------------------------------+ +----------------------------+ +----------------------------+ | ContextGet('input_documents') | | ContextGet('chat_history') | | ContextGet('new_question') | +-------------------------------+**** +----------------------------+ +----------------------------+ ********* * ******* ******** * ****** ***** * **** +-----------+ | MapOutput | +-----------+ * * * +-------------+ | FakeListLLM | +-------------+ * * * +----------+ ***| MapInput |*** ******** +----------+ ****** ******* * ***** ******** * ****** **** * *** 
+-------------------------------+ +----------------------------+ +-------------+ | ContextGet('input_documents') | | ContextGet('new_question') | **| Passthrough | +-------------------------------+ +----------------------------+ ******* +-------------+ ******* * ****** ****** * ******* **** * **** +-----------+ | MapOutput | +-----------+ ``` --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
2023-12-27 01:28:10 +00:00
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
    """Abstract base class for a Document retrieval system.

    A retrieval system is defined as something that can take string queries and return
    the most 'relevant' Documents from some source.

    Usage:

    A retriever follows the standard Runnable interface, and should be used
    via the standard runnable methods of `invoke`, `ainvoke`, `batch`, `abatch`.

    Implementation:

    When implementing a custom retriever, the class should implement
    the `_get_relevant_documents` method to define the logic for retrieving documents.

    Optionally, an async native implementation can be provided by overriding the
    `_aget_relevant_documents` method.

    Example: A retriever that returns the first 5 documents from a list of documents

        .. code-block:: python

            from langchain_core.documents import Document
            from langchain_core.retrievers import BaseRetriever
            from typing import List

            class SimpleRetriever(BaseRetriever):
                docs: List[Document]
                k: int = 5

                def _get_relevant_documents(self, query: str) -> List[Document]:
                    \"\"\"Return the first k documents from the list of documents\"\"\"
                    return self.docs[:self.k]

                async def _aget_relevant_documents(self, query: str) -> List[Document]:
                    \"\"\"(Optional) async native implementation.\"\"\"
                    return self.docs[:self.k]

    Example: A simple retriever based on a scikit-learn vectorizer

        .. code-block:: python

            from sklearn.metrics.pairwise import cosine_similarity

            class TFIDFRetriever(BaseRetriever, BaseModel):
                vectorizer: Any
                docs: List[Document]
                tfidf_array: Any
                k: int = 4

                class Config:
                    arbitrary_types_allowed = True

                def _get_relevant_documents(self, query: str) -> List[Document]:
                    # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
                    query_vec = self.vectorizer.transform([query])
                    # Op -- (n_docs,1) -- Cosine Sim with each doc
                    results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
                    return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    # Set per subclass in `__init_subclass__`:
    # True when `_get_relevant_documents` declares a `run_manager` parameter.
    _new_arg_supported: bool = False
    # True when `_get_relevant_documents` declares parameters other than
    # `self`, `query` and `run_manager`.
    _expects_other_args: bool = False
    tags: Optional[List[str]] = None
    """Optional list of tags associated with the retriever. Defaults to None
    These tags will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a retriever with its
    use case.
    """
    metadata: Optional[Dict[str, Any]] = None
    """Optional metadata associated with the retriever. Defaults to None
    This metadata will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a retriever with its
    use case.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Adapt legacy subclasses and record their implementation signature.

        Subclasses that override the public ``get_relevant_documents`` /
        ``aget_relevant_documents`` methods are rewired so that their override
        becomes the private ``_get_relevant_documents`` /
        ``_aget_relevant_documents`` implementation, and a
        ``DeprecationWarning`` is emitted. The signature of the resulting
        ``_get_relevant_documents`` is then inspected to set
        ``_new_arg_supported`` and ``_expects_other_args``.

        Args:
            **kwargs: Forwarded unchanged to ``super().__init_subclass__``.
        """
        super().__init_subclass__(**kwargs)
        # Version upgrade for old retrievers that implemented the public
        # methods directly.
        if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
            warnings.warn(
                "Retrievers must implement abstract `_get_relevant_documents` method"
                " instead of `get_relevant_documents`",
                DeprecationWarning,
            )
            swap = cls.get_relevant_documents
            cls.get_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.get_relevant_documents
            )
            cls._get_relevant_documents = swap  # type: ignore[assignment]
        if (
            hasattr(cls, "aget_relevant_documents")
            and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
        ):
            warnings.warn(
                "Retrievers must implement abstract `_aget_relevant_documents` method"
                " instead of `aget_relevant_documents`",
                DeprecationWarning,
            )
            aswap = cls.aget_relevant_documents
            cls.aget_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.aget_relevant_documents
            )
            cls._aget_relevant_documents = aswap  # type: ignore[assignment]
        parameters = signature(cls._get_relevant_documents).parameters
        cls._new_arg_supported = parameters.get("run_manager") is not None
        if (
            not cls._new_arg_supported
            and cls._aget_relevant_documents == BaseRetriever._aget_relevant_documents
        ):
            # The sync implementation uses the legacy signature (no
            # `run_manager`) and no async implementation was provided.
            # `aget_relevant_documents` will therefore call
            # `_aget_relevant_documents(query)` without `run_manager`, which the
            # base implementation rejects (keyword-only, required). Install a
            # fallback that tolerates the missing argument.
            async def _aget_relevant_documents(
                self: BaseRetriever, query: str
            ) -> List[Document]:
                return await run_in_executor(
                    None,
                    self._get_relevant_documents,  # type: ignore[call-arg]
                    query,
                )

            cls._aget_relevant_documents = (  # type: ignore[assignment]
                _aget_relevant_documents
            )
        # If a V1 retriever broke the interface and expects additional arguments
        cls._expects_other_args = (
            len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
        )

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> List[Document]:
        """Invoke the retriever to get relevant documents.

        Main entry point for synchronous retriever invocations.

        Args:
            input: The query string
            config: Configuration for the retriever
            **kwargs: Additional arguments to pass to the retriever

        Returns:
            List of relevant documents

        Examples:

        .. code-block:: python

            retriever.invoke("query")
        """
        config = ensure_config(config)
        return self.get_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            **kwargs,
        )

    async def ainvoke(
        self,
        input: str,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously invoke the retriever to get relevant documents.

        Main entry point for asynchronous retriever invocations.

        Args:
            input: The query string
            config: Configuration for the retriever
            **kwargs: Additional arguments to pass to the retriever

        Returns:
            List of relevant documents

        Examples:

        .. code-block:: python

            await retriever.ainvoke("query")
        """
        config = ensure_config(config)
        return await self.aget_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            **kwargs,
        )

    @abstractmethod
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        The default implementation runs the synchronous
        `_get_relevant_documents` via `run_in_executor`, handing it the sync
        counterpart of `run_manager`.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        return await run_in_executor(
            None,
            self._get_relevant_documents,
            query,
            run_manager=run_manager.get_sync(),
        )

    def get_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve documents relevant to a query.

        Users should favor using `.invoke` or `.batch` rather than
        `get_relevant_documents` directly.

        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            run_name: Optional name for the run.
            **kwargs: `verbose` and `run_id` are read for callback setup; the
                remaining entries are forwarded to `_get_relevant_documents`
                only when the subclass declares extra parameters.

        Returns:
            List of relevant documents

        Raises:
            Exception: Whatever `_get_relevant_documents` raises, re-raised
                after `on_retriever_error` has been reported.
        """
        from langchain_core.callbacks.manager import CallbackManager

        callback_manager = CallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            run_id=kwargs.pop("run_id", None),
        )
        try:
            # Only legacy implementations that declared extra parameters get
            # the caller's additional keyword arguments.
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = self._get_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                result = self._get_relevant_documents(query, **_kwargs)
        except Exception as e:
            run_manager.on_retriever_error(e)
            # Bare raise keeps the original traceback intact.
            raise
        else:
            run_manager.on_retriever_end(
                result,
            )
            return result

    async def aget_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        Users should favor using `.ainvoke` or `.abatch` rather than
        `aget_relevant_documents` directly.

        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            run_name: Optional name for the run.
            **kwargs: `verbose` and `run_id` are read for callback setup; the
                remaining entries are forwarded to `_aget_relevant_documents`
                only when the subclass declares extra parameters.

        Returns:
            List of relevant documents

        Raises:
            Exception: Whatever `_aget_relevant_documents` raises, re-raised
                after `on_retriever_error` has been reported.
        """
        from langchain_core.callbacks.manager import AsyncCallbackManager

        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = await callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            run_id=kwargs.pop("run_id", None),
        )
        try:
            # Only legacy implementations that declared extra parameters get
            # the caller's additional keyword arguments.
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = await self._aget_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                result = await self._aget_relevant_documents(query, **_kwargs)
        except Exception as e:
            await run_manager.on_retriever_error(e)
            # Bare raise keeps the original traceback intact.
            raise
        else:
            await run_manager.on_retriever_end(
                result,
            )
            return result