mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
780337488e
The root run id (~trace id) is useful for assigning feedback, but the current recommended approach is to use callbacks to retrieve it, which has some drawbacks: 1. Doesn't work for streaming until after the first event 2. Doesn't let you call other endpoints with the same trace ID in parallel (since you have to wait until the call is completed/started to use it). This PR lets you provide a "run_id" in the runnable config. Couple of considerations: 1. For batch calls, we split the trace up into separate trees (to permit better rendering). We keep the provided run ID for the first one and generate a unique one for other elements of the batch. 2. For nested calls, the provided ID is ONLY used on the top root/trace. ### Example Usage ``` chain.invoke("foo", {"run_id": uuid.uuid4()}) ```
308 lines
11 KiB
Python
308 lines
11 KiB
Python
"""**Retriever** class returns Documents given a text **query**.
|
|
|
|
It is more general than a vector store. A retriever does not need to be able to
|
|
store documents, only to return (or retrieve) them. Vector stores can be used as
|
|
the backbone of a retriever, but there are other types of retrievers as well.
|
|
|
|
**Class hierarchy:**
|
|
|
|
.. code-block::
|
|
|
|
BaseRetriever --> <name>Retriever # Examples: ArxivRetriever, MergerRetriever
|
|
|
|
**Main helpers:**
|
|
|
|
.. code-block::
|
|
|
|
RetrieverInput, RetrieverOutput, RetrieverLike, RetrieverOutputLike,
|
|
Document, Serializable, Callbacks,
|
|
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from inspect import signature
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
|
|
|
from langchain_core.documents import Document
|
|
from langchain_core.load.dump import dumpd
|
|
from langchain_core.runnables import (
|
|
Runnable,
|
|
RunnableConfig,
|
|
RunnableSerializable,
|
|
ensure_config,
|
|
)
|
|
from langchain_core.runnables.config import run_in_executor
|
|
|
|
if TYPE_CHECKING:
|
|
from langchain_core.callbacks.manager import (
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
CallbackManagerForRetrieverRun,
|
|
Callbacks,
|
|
)
|
|
|
|
# A retriever is queried with a plain string.
RetrieverInput = str
# A retriever answers with a list of documents.
RetrieverOutput = List[Document]
# Any Runnable that maps a string query to a list of documents.
RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
# Any Runnable that produces a list of documents, whatever its input type.
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
|
|
|
|
|
|
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
    """Abstract base class for a Document retrieval system.

    A retrieval system is defined as something that can take string queries and return
    the most 'relevant' Documents from some source.

    Subclasses implement `_get_relevant_documents` (and may override
    `_aget_relevant_documents`); callers use `invoke` / `ainvoke` or
    `get_relevant_documents` / `aget_relevant_documents`, which wrap the
    implementation with callback handling.

    Example:
        .. code-block:: python

            class TFIDFRetriever(BaseRetriever, BaseModel):
                vectorizer: Any
                docs: List[Document]
                tfidf_array: Any
                k: int = 4

                class Config:
                    arbitrary_types_allowed = True

                def _get_relevant_documents(
                    self, query: str, *, run_manager: CallbackManagerForRetrieverRun
                ) -> List[Document]:
                    from sklearn.metrics.pairwise import cosine_similarity

                    # Ip -- (n_docs,x), Op -- (n_docs,n_Feats)
                    query_vec = self.vectorizer.transform([query])
                    # Op -- (n_docs,1) -- Cosine Sim with each doc
                    results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
                    return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
    """  # noqa: E501

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    # Both flags are recomputed per subclass in ``__init_subclass__`` from the
    # signature of that subclass's ``_get_relevant_documents``.
    _new_arg_supported: bool = False
    _expects_other_args: bool = False
    tags: Optional[List[str]] = None
    """Optional list of tags associated with the retriever. Defaults to None
    These tags will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a retriever with its
    use case.
    """
    metadata: Optional[Dict[str, Any]] = None
    """Optional metadata associated with the retriever. Defaults to None
    This metadata will be associated with each call to this retriever,
    and passed as arguments to the handlers defined in `callbacks`.
    You can use these to eg identify a specific instance of a retriever with its
    use case.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Adapt legacy subclasses and record what their implementation accepts.

        If a subclass overrides the public `get_relevant_documents` /
        `aget_relevant_documents` directly, a DeprecationWarning is emitted and
        the override is moved to the private `_get_relevant_documents` /
        `_aget_relevant_documents` slot, restoring the base-class public method.
        The signature of `_get_relevant_documents` is then inspected to set
        `_new_arg_supported` and `_expects_other_args`.
        """
        super().__init_subclass__(**kwargs)
        # Version upgrade for old retrievers that implemented the public
        # methods directly.
        if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
            warnings.warn(
                "Retrievers must implement abstract `_get_relevant_documents` method"
                " instead of `get_relevant_documents`",
                DeprecationWarning,
            )
            swap = cls.get_relevant_documents
            cls.get_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.get_relevant_documents
            )
            cls._get_relevant_documents = swap  # type: ignore[assignment]
        if (
            hasattr(cls, "aget_relevant_documents")
            and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
        ):
            warnings.warn(
                "Retrievers must implement abstract `_aget_relevant_documents` method"
                " instead of `aget_relevant_documents`",
                DeprecationWarning,
            )
            aswap = cls.aget_relevant_documents
            cls.aget_relevant_documents = (  # type: ignore[assignment]
                BaseRetriever.aget_relevant_documents
            )
            cls._aget_relevant_documents = aswap  # type: ignore[assignment]
        parameters = signature(cls._get_relevant_documents).parameters
        cls._new_arg_supported = "run_manager" in parameters
        # If a V1 retriever broke the interface and expects additional arguments
        cls._expects_other_args = bool(
            set(parameters) - {"self", "query", "run_manager"}
        )

    def invoke(
        self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> List[Document]:
        """Retrieve documents for `input`, taking callback settings from `config`.

        Args:
            input: String to find relevant documents for
            config: Optional runnable config. Its ``callbacks``, ``tags``,
                ``metadata``, ``run_name`` and ``run_id`` entries are forwarded.
            **kwargs: Extra keyword arguments passed on to
                `get_relevant_documents`. An explicit ``run_id`` here takes
                precedence over the one in `config`.
        Returns:
            List of relevant documents
        """
        config = ensure_config(config)
        # Pull run_id out of kwargs so it is passed exactly once below.
        run_id = kwargs.pop("run_id", None)
        if run_id is None:
            run_id = config.get("run_id")
        return self.get_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=run_id,
            **kwargs,
        )

    async def ainvoke(
        self,
        input: str,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously retrieve documents for `input` using `config`.

        Args:
            input: String to find relevant documents for
            config: Optional runnable config. Its ``callbacks``, ``tags``,
                ``metadata``, ``run_name`` and ``run_id`` entries are forwarded.
            **kwargs: Extra keyword arguments passed on to
                `aget_relevant_documents`. An explicit ``run_id`` here takes
                precedence over the one in `config`.
        Returns:
            List of relevant documents
        """
        config = ensure_config(config)
        # Pull run_id out of kwargs so it is passed exactly once below.
        run_id = kwargs.pop("run_id", None)
        if run_id is None:
            run_id = config.get("run_id")
        return await self.aget_relevant_documents(
            input,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=run_id,
            **kwargs,
        )

    @abstractmethod
    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.
        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.

        The default implementation hands the synchronous
        `_get_relevant_documents` to `run_in_executor`, passing the sync
        counterpart of `run_manager`.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        return await run_in_executor(
            None,
            self._get_relevant_documents,
            query,
            run_manager=run_manager.get_sync(),
        )

    def get_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Retrieve documents relevant to a query.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            run_name: Optional name passed to the retriever-start callback.
            **kwargs: ``run_id`` (removed and passed to the retriever-start
                callback) and ``verbose`` (read for callback configuration) are
                recognised. All remaining kwargs, ``verbose`` included, are
                forwarded to `_get_relevant_documents` only when its signature
                declares extra parameters.
        Returns:
            List of relevant documents
        """
        from langchain_core.callbacks.manager import CallbackManager

        callback_manager = CallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            # pop, not get: run_id must never reach the implementation below
            run_id=kwargs.pop("run_id", None),
        )
        try:
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = self._get_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                # Legacy implementation without a ``run_manager`` parameter
                result = self._get_relevant_documents(query, **_kwargs)
        except Exception as e:
            run_manager.on_retriever_error(e)
            raise
        else:
            run_manager.on_retriever_end(
                result,
            )
            return result

    async def aget_relevant_documents(
        self,
        query: str,
        *,
        callbacks: Callbacks = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.
        Args:
            query: string to find relevant documents for
            callbacks: Callback manager or list of callbacks
            tags: Optional list of tags associated with the retriever. Defaults to None
                These tags will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            metadata: Optional metadata associated with the retriever. Defaults to None
                This metadata will be associated with each call to this retriever,
                and passed as arguments to the handlers defined in `callbacks`.
            run_name: Optional name passed to the retriever-start callback.
            **kwargs: ``run_id`` (removed and passed to the retriever-start
                callback) and ``verbose`` (read for callback configuration) are
                recognised. All remaining kwargs, ``verbose`` included, are
                forwarded to `_aget_relevant_documents` only when the signature
                of `_get_relevant_documents` declares extra parameters.
        Returns:
            List of relevant documents
        """
        from langchain_core.callbacks.manager import AsyncCallbackManager

        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=tags,
            local_tags=self.tags,
            inheritable_metadata=metadata,
            local_metadata=self.metadata,
        )
        run_manager = await callback_manager.on_retriever_start(
            dumpd(self),
            query,
            name=run_name,
            # pop, not get: run_id must never reach the implementation below
            run_id=kwargs.pop("run_id", None),
        )
        try:
            _kwargs = kwargs if self._expects_other_args else {}
            if self._new_arg_supported:
                result = await self._aget_relevant_documents(
                    query, run_manager=run_manager, **_kwargs
                )
            else:
                # Legacy implementation without a ``run_manager`` parameter
                result = await self._aget_relevant_documents(query, **_kwargs)
        except Exception as e:
            await run_manager.on_retriever_error(e)
            raise
        else:
            await run_manager.on_retriever_end(
                result,
            )
            return result
|