Add New Retriever Interface with Callbacks (#5962)

Handle the new retriever events in a way that (I think) is entirely
backwards compatible? Needs more testing for some of the chain changes
and all.

This creates an entire new run type, however. We could also just treat
this as an event within a chain run presumably (same with memory)

Adds a subclass initializer that upgrades old retriever implementations
to the new schema, along with tests to ensure they work.

First commit doesn't upgrade any of our retriever implementations (to
show that we can pass the tests along with additional ones testing the
upgrade logic).

Second commit upgrades the known universe of retrievers in langchain.

- [X] Add callback handling methods for retriever start/end/error (open
to renaming to 'retrieval' if you want that)
- [X] Update BaseRetriever schema to support callbacks
- [X] Tests for upgrading old "v1" retrievers for backwards
compatibility
- [X] Update existing retriever implementations to implement the new
interface
- [X] Update calls within chains to .{a}get_relevant_documents to pass
the child callback manager
- [X] Update the notebooks/docs to reflect the new interface
- [X] Test notebooks thoroughly


Not handled:
- Memory pass throughs: retrieval memory doesn't have a parent callback
manager passed through the method

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
pull/7005/head
Zander Chase 1 year ago committed by GitHub
parent a5b206caf3
commit b0859c9b18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,11 +71,13 @@
"import numpy as np\n",
"\n",
"from langchain.schema import BaseRetriever\n",
"from langchain.callbacks.manager import AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun\n",
"from langchain.utilities import GoogleSerperAPIWrapper\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.llms import OpenAI\n",
"from langchain.schema import Document"
"from langchain.schema import Document\n",
"from typing import Any"
]
},
{
@ -97,11 +99,16 @@
" def __init__(self, search):\n",
" self.search = search\n",
"\n",
" def get_relevant_documents(self, query: str):\n",
" def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any) -> List[Document]:\n",
" return [Document(page_content=self.search.run(query))]\n",
"\n",
" async def aget_relevant_documents(self, query: str):\n",
" raise NotImplemented\n",
" async def _aget_relevant_documents(self,\n",
" query: str,\n",
" *,\n",
" run_manager: AsyncCallbackManagerForRetrieverRun,\n",
" **kwargs: Any,\n",
" ) -> List[Document]:\n",
" raise NotImplementedError()\n",
"\n",
"\n",
"retriever = SerperSearchRetriever(GoogleSerperAPIWrapper())"

@ -43,7 +43,6 @@
"import openai\n",
"from dotenv import load_dotenv\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.schema import BaseRetriever\n",
"from langchain.vectorstores.azuresearch import AzureSearch"
]
},

@ -1,24 +1,40 @@
The `BaseRetriever` class in LangChain is as follows:
The public API of the `BaseRetriever` class in LangChain is as follows:
```python
from abc import ABC, abstractmethod
from typing import List
from typing import Any, List
from langchain.schema import Document
from langchain.callbacks.manager import Callbacks
class BaseRetriever(ABC):
@abstractmethod
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get texts relevant for a query.
...
def get_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: string to find relevant texts for
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
Returns:
List of relevant documents
"""
...
async def aget_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
Returns:
List of relevant documents
"""
...
```
It's that simple! The `get_relevant_documents` method can be implemented however you see fit.
It's that simple! You can call the `get_relevant_documents` or the async `aget_relevant_documents` methods to retrieve documents relevant to a query, where "relevance" is defined by
the specific retriever object you are calling.
Of course, we also help construct what we think useful Retrievers are. The main type of Retriever that we focus on is a Vectorstore retriever. We will focus on that for the rest of this guide.

@ -0,0 +1,162 @@
# Implement a Custom Retriever
In this walkthrough, you will implement a simple custom retriever in LangChain using a simple dot product distance lookup.
All retrievers inherit from the `BaseRetriever` class and override the following abstract methods:
```python
from abc import ABC, abstractmethod
from typing import Any, List
from langchain.schema import Document
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
class BaseRetriever(ABC):
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: string to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
@abstractmethod
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
```
The `_get_relevant_documents` and async `_aget_relevant_documents` methods can be implemented however you see fit. The `run_manager` is useful if your retriever calls other traceable LangChain primitives like LLMs, chains, or tools.
Below, implement an example that fetches the most similar documents from a list of documents using a numpy array of embeddings.
```python
from typing import Any, List, Optional
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
class NumpyRetriever(BaseRetriever):
    """Retrieves documents from a numpy array of embeddings.

    Relevance is scored by the dot product between the query embedding
    and each stored document vector.
    """

    def __init__(
        self,
        texts: List[str],
        vectors: np.ndarray,
        embeddings: Optional[Embeddings] = None,
        num_to_return: int = 1,
    ) -> None:
        super().__init__()
        # Default to OpenAI embeddings only when none are injected, so
        # callers can supply any Embeddings implementation (e.g. for tests).
        self.embeddings = embeddings or OpenAIEmbeddings()
        self.texts = texts
        self.vectors = vectors
        self.num_to_return = num_to_return

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embeddings: Optional[Embeddings] = None,
        **kwargs: Any,
    ) -> "NumpyRetriever":
        """Construct a retriever by embedding ``texts``.

        Args:
            texts: documents to index.
            embeddings: optional embeddings model; defaults to OpenAI.
            **kwargs: forwarded to the constructor (e.g. ``num_to_return``).
        """
        embeddings = embeddings or OpenAIEmbeddings()
        vectors = np.array(embeddings.embed_documents(texts))
        # Forward kwargs so options like num_to_return are not silently dropped.
        return cls(texts, vectors, embeddings, **kwargs)

    def _get_relevant_documents_from_query_vector(
        self, vector_query: np.ndarray
    ) -> List[Document]:
        """Return the top ``num_to_return`` documents by dot-product score."""
        dot_product = np.dot(self.vectors, vector_query)
        # Select the indices of the `num_to_return` highest-scoring documents
        # (argpartition is O(n); only the selected slice needs sorting).
        indices = np.argpartition(
            dot_product, -min(self.num_to_return, len(self.vectors))
        )[-self.num_to_return :]
        # Sort the selected indices by ascending score.
        indices = indices[np.argsort(dot_product[indices])]
        return [
            Document(
                page_content=self.texts[idx],
                metadata={"index": idx},
            )
            for idx in indices
        ]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Get documents relevant to a query.
        Args:
            query: string to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        vector_query = np.array(self.embeddings.embed_query(query))
        return self._get_relevant_documents_from_query_vector(vector_query)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Asynchronously get documents relevant to a query.
        Args:
            query: string to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant documents
        """
        query_emb = await self.embeddings.aembed_query(query)
        return self._get_relevant_documents_from_query_vector(np.array(query_emb))
```
The retriever can be instantiated through the class method `from_texts`. It embeds the texts and stores them in a numpy array. To look up documents, it embeds the query and finds the most similar documents using a simple dot product distance.
Once the retriever is implemented, you can use it like any other retriever in LangChain.
```python
retriever = NumpyRetriever.from_texts(texts=["hello world", "goodbye world"])
```
You can then use the retriever to get relevant documents.
```python
retriever.get_relevant_documents("Hi there!")
# [Document(page_content='hello world', metadata={'index': 0})]
```
```python
retriever.get_relevant_documents("Bye!")
# [Document(page_content='goodbye world', metadata={'index': 1})]
```

@ -30,6 +30,7 @@ class BaseMetadataCallbackHandler:
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
@ -52,6 +53,7 @@ class BaseMetadataCallbackHandler:
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0
@ -86,6 +88,11 @@ class BaseMetadataCallbackHandler:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
def get_custom_callback_meta(self) -> Dict[str, Any]:
return {
"step": self.step,

@ -1,15 +1,34 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
)
from langchain.schema import AgentAction, AgentFinish, BaseMessage, Document, LLMResult
class RetrieverManagerMixin:
"""Mixin for Retriever callbacks."""
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever ends running."""
class LLMManagerMixin:
@ -144,6 +163,16 @@ class CallbackManagerMixin:
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
def on_retriever_start(
self,
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
def on_chain_start(
self,
serialized: Dict[str, Any],
@ -187,6 +216,7 @@ class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
RetrieverManagerMixin,
CallbackManagerMixin,
RunManagerMixin,
):
@ -211,6 +241,11 @@ class BaseCallbackHandler(
"""Whether to ignore agent callbacks."""
return False
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return False
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
@ -371,6 +406,36 @@ class AsyncCallbackHandler(BaseCallbackHandler):
) -> None:
"""Run on agent end."""
async def on_retriever_start(
self,
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
async def on_retriever_end(
self,
documents: Sequence[Document],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever end."""
async def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on retriever error."""
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that can be used to handle callbacks from LangChain."""

@ -14,6 +14,7 @@ from typing import (
Generator,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
@ -27,6 +28,7 @@ from langchain.callbacks.base import (
BaseCallbackManager,
ChainManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
@ -40,6 +42,7 @@ from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
Document,
LLMResult,
get_buffer_string,
)
@ -899,6 +902,97 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
)
class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
"""Callback manager for retriever run."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
def on_retriever_end(
self,
documents: Sequence[Document],
**kwargs: Any,
) -> None:
"""Run when retriever ends running."""
_handle_event(
self.handlers,
"on_retriever_end",
"ignore_retriever",
documents,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
_handle_event(
self.handlers,
"on_retriever_error",
"ignore_retriever",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForRetrieverRun(
AsyncRunManager,
RetrieverManagerMixin,
):
"""Async callback manager for retriever run."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
"""Run when retriever ends running."""
await _ahandle_event(
self.handlers,
"on_retriever_end",
"ignore_retriever",
documents,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
await _ahandle_event(
self.handlers,
"on_retriever_error",
"ignore_retriever",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
@ -1077,6 +1171,36 @@ class CallbackManager(BaseCallbackManager):
inheritable_tags=self.inheritable_tags,
)
def on_retriever_start(
self,
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_retriever_start",
"ignore_retriever",
query,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return CallbackManagerForRetrieverRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
@classmethod
def configure(
cls,
@ -1313,6 +1437,36 @@ class AsyncCallbackManager(BaseCallbackManager):
inheritable_tags=self.inheritable_tags,
)
async def on_retriever_start(
self,
query: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_retriever_start",
"ignore_retriever",
query,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForRetrieverRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
@classmethod
def configure(
cls,

@ -4,12 +4,12 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.schema import LLMResult
from langchain.schema import Document, LLMResult
logger = logging.getLogger(__name__)
@ -265,6 +265,65 @@ class BaseTracer(BaseCallbackHandler, ABC):
self._end_trace(tool_run)
self._on_tool_error(tool_run)
def on_retriever_start(
self,
query: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when Retriever starts running."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
retrieval_run = Run(
id=run_id,
name="Retriever",
parent_run_id=parent_run_id,
inputs={"query": query},
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.retriever,
)
self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run)
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
**kwargs: Any,
) -> None:
"""Run when Retriever errors."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
raise TracerException("No retriever Run found to be traced")
retrieval_run.error = repr(error)
retrieval_run.end_time = datetime.utcnow()
self._end_trace(retrieval_run)
self._on_retriever_error(retrieval_run)
def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> None:
"""Run when Retriever ends running."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
raise TracerException("No retriever Run found to be traced")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()
self._end_trace(retrieval_run)
self._on_retriever_end(retrieval_run)
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
return self
@ -302,3 +361,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _on_chat_model_start(self, run: Run) -> None:
"""Process the Chat Model Run upon start."""
def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""

@ -180,6 +180,24 @@ class LangChainTracer(BaseTracer):
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start."""
self._futures.add(
self.executor.submit(self._persist_run_single, run.copy(deep=True))
)
def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run."""
self._futures.add(
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error."""
self._futures.add(
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def wait_for_futures(self) -> None:
"""Wait for the given futures to complete."""
futures = list(self._futures)

@ -119,6 +119,7 @@ class BaseMetadataCallbackHandler:
ignore_llm_ (bool): Whether to ignore llm callbacks.
ignore_chain_ (bool): Whether to ignore chain callbacks.
ignore_agent_ (bool): Whether to ignore agent callbacks.
ignore_retriever_ (bool): Whether to ignore retriever callbacks.
always_verbose_ (bool): Whether to always be verbose.
chain_starts (int): The number of times the chain start method has been called.
chain_ends (int): The number of times the chain end method has been called.
@ -149,6 +150,7 @@ class BaseMetadataCallbackHandler:
self.ignore_llm_ = False
self.ignore_chain_ = False
self.ignore_agent_ = False
self.ignore_retriever_ = False
self.always_verbose_ = False
self.chain_starts = 0

@ -1,6 +1,7 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from pathlib import Path
@ -87,7 +88,13 @@ class BaseConversationalRetrievalChain(Chain):
return _output_keys
@abstractmethod
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
def _call(
@ -107,7 +114,13 @@ class BaseConversationalRetrievalChain(Chain):
)
else:
new_question = question
docs = self._get_docs(new_question, inputs)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
else:
docs = self._get_docs(new_question, inputs) # type: ignore[call-arg]
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@ -122,7 +135,13 @@ class BaseConversationalRetrievalChain(Chain):
return output
@abstractmethod
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
async def _acall(
@ -141,7 +160,14 @@ class BaseConversationalRetrievalChain(Chain):
)
else:
new_question = question
docs = await self._aget_docs(new_question, inputs)
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
if accepts_run_manager:
docs = await self._aget_docs(new_question, inputs, run_manager=_run_manager)
else:
docs = await self._aget_docs(new_question, inputs) # type: ignore[call-arg]
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
@ -187,12 +213,30 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
return docs[:num_docs]
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
docs = self.retriever.get_relevant_documents(question)
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
docs = self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
docs = await self.retriever.aget_relevant_documents(question)
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
docs = await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)
return self._reduce_tokens_below_limit(docs)
@classmethod
@ -253,14 +297,28 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
)
return values
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
vectordbkwargs = inputs.get("vectordbkwargs", {})
full_kwargs = {**self.search_kwargs, **vectordbkwargs}
return self.vectorstore.similarity_search(
question, k=self.top_k_docs_for_context, **full_kwargs
)
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
question: str,
inputs: Dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
raise NotImplementedError("ChatVectorDBChain does not support async")
@classmethod

@ -2,6 +2,7 @@
from __future__ import annotations
import inspect
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
@ -115,7 +116,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
return values
@abstractmethod
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs to run questioning over."""
def _call(
@ -124,7 +130,14 @@ class BaseQAWithSourcesChain(Chain, ABC):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = self._get_docs(inputs)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(inputs, run_manager=_run_manager)
else:
docs = self._get_docs(inputs) # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
@ -141,7 +154,12 @@ class BaseQAWithSourcesChain(Chain, ABC):
return result
@abstractmethod
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs to run questioning over."""
async def _acall(
@ -150,7 +168,13 @@ class BaseQAWithSourcesChain(Chain, ABC):
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = await self._aget_docs(inputs)
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
if accepts_run_manager:
docs = await self._aget_docs(inputs, run_manager=_run_manager)
else:
docs = await self._aget_docs(inputs) # type: ignore[call-arg]
answer = await self.combine_documents_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
@ -180,10 +204,22 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
"""
return [self.input_docs_key, self.question_key]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs to run questioning over."""
return inputs.pop(self.input_docs_key)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self,
inputs: Dict[str, Any],
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs to run questioning over."""
return inputs.pop(self.input_docs_key)
@property

@ -1,4 +1,6 @@
"""Load question answering with sources chains."""
from __future__ import annotations
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
@ -13,7 +15,9 @@ from langchain.chains.qa_with_sources import (
refine_prompts,
stuff_prompt,
)
from langchain.chains.question_answering import map_rerank_prompt
from langchain.chains.question_answering.map_rerank_prompt import (
PROMPT as MAP_RERANK_PROMPT,
)
from langchain.prompts.base import BasePromptTemplate
@ -28,7 +32,7 @@ class LoadingCallable(Protocol):
def _load_map_rerank_chain(
llm: BaseLanguageModel,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
prompt: BasePromptTemplate = MAP_RERANK_PROMPT,
verbose: bool = False,
document_variable_name: str = "context",
rank_key: str = "score",

@ -4,6 +4,10 @@ from typing import Any, Dict, List
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
@ -40,12 +44,20 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
question = inputs[self.question_key]
docs = self.retriever.get_relevant_documents(question)
docs = self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
question = inputs[self.question_key]
docs = await self.retriever.aget_relevant_documents(question)
docs = await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)
return self._reduce_tokens_below_limit(docs)

@ -5,6 +5,10 @@ from typing import Any, Dict, List
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
@ -45,14 +49,18 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
def _get_docs(
self, inputs: Dict[str, Any], *, run_manager: CallbackManagerForChainRun
) -> List[Document]:
question = inputs[self.question_key]
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
)
return self._reduce_tokens_below_limit(docs)
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
async def _aget_docs(
self, inputs: Dict[str, Any], *, run_manager: AsyncCallbackManagerForChainRun
) -> List[Document]:
raise NotImplementedError("VectorDBQAWithSourcesChain does not support async")
@root_validator()

@ -12,10 +12,12 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import (
map_reduce_prompt,
map_rerank_prompt,
refine_prompts,
stuff_prompt,
)
from langchain.chains.question_answering.map_rerank_prompt import (
PROMPT as MAP_RERANK_PROMPT,
)
from langchain.prompts.base import BasePromptTemplate
@ -30,7 +32,7 @@ class LoadingCallable(Protocol):
def _load_map_rerank_chain(
llm: BaseLanguageModel,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
prompt: BasePromptTemplate = MAP_RERANK_PROMPT,
verbose: bool = False,
document_variable_name: str = "context",
rank_key: str = "score",

@ -1,6 +1,7 @@
"""Chain for question-answering against a vector database."""
from __future__ import annotations
import inspect
import warnings
from abc import abstractmethod
from typing import Any, Dict, List, Optional
@ -94,7 +95,12 @@ class BaseRetrievalQA(Chain):
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
@abstractmethod
def _get_docs(self, question: str) -> List[Document]:
def _get_docs(
self,
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get documents to do question answering over."""
def _call(
@ -115,8 +121,13 @@ class BaseRetrievalQA(Chain):
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = self._get_docs(question)
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:
docs = self._get_docs(question, run_manager=_run_manager)
else:
docs = self._get_docs(question) # type: ignore[call-arg]
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
@ -127,7 +138,12 @@ class BaseRetrievalQA(Chain):
return {self.output_key: answer}
@abstractmethod
async def _aget_docs(self, question: str) -> List[Document]:
async def _aget_docs(
self,
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get documents to do question answering over."""
async def _acall(
@ -148,8 +164,13 @@ class BaseRetrievalQA(Chain):
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = await self._aget_docs(question)
accepts_run_manager = (
"run_manager" in inspect.signature(self._aget_docs).parameters
)
if accepts_run_manager:
docs = await self._aget_docs(question, run_manager=_run_manager)
else:
docs = await self._aget_docs(question) # type: ignore[call-arg]
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
@ -177,11 +198,27 @@ class RetrievalQA(BaseRetrievalQA):
retriever: BaseRetriever = Field(exclude=True)
def _get_docs(self, question: str) -> List[Document]:
return self.retriever.get_relevant_documents(question)
def _get_docs(
self,
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return self.retriever.get_relevant_documents(
question, callbacks=run_manager.get_child()
)
async def _aget_docs(self, question: str) -> List[Document]:
return await self.retriever.aget_relevant_documents(question)
async def _aget_docs(
self,
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)
@property
def _chain_type(self) -> str:
@ -218,7 +255,13 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def _get_docs(self, question: str) -> List[Document]:
def _get_docs(
self,
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
@ -231,7 +274,13 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def _aget_docs(self, question: str) -> List[Document]:
async def _aget_docs(
self,
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
raise NotImplementedError("VectorDBQA does not support async")
@property

@ -697,7 +697,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.11.2"
}
},
"nbformat": 4,

@ -16,7 +16,7 @@ from langchain.retrievers.metal import MetalRetriever
from langchain.retrievers.milvus import MilvusRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever
from langchain.retrievers.pupmed import PubMedRetriever
from langchain.retrievers.pubmed import PubMedRetriever
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.svm import SVMRetriever

@ -1,5 +1,9 @@
from typing import List
from typing import Any, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
from langchain.utilities.arxiv import ArxivAPIWrapper
@ -11,8 +15,20 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
It uses all ArxivAPIWrapper arguments without any change.
"""
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return self.load(query=query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,13 +1,18 @@
"""Retriever wrapper for Azure Cognitive Search."""
from __future__ import annotations
import json
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
import aiohttp
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env
@ -81,7 +86,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
return response_json["value"]
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
search_results = self._search(query)
return [
@ -89,7 +100,13 @@ class AzureCognitiveSearchRetriever(BaseRetriever, BaseModel):
for result in search_results
]
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
search_results = await self._asearch(query)
return [

@ -1,11 +1,15 @@
from __future__ import annotations
from typing import List, Optional
from typing import Any, List, Optional
import aiohttp
import requests
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
@ -21,7 +25,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
url, json, headers = self._create_request(query)
response = requests.post(url, json=json, headers=headers)
results = response.json()["results"][0]["results"]
@ -34,7 +44,13 @@ class ChatGPTPluginRetriever(BaseRetriever, BaseModel):
docs.append(Document(page_content=content, metadata=metadata))
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
url, json, headers = self._create_request(query)
if not self.aiosession:

@ -1,8 +1,13 @@
"""Retriever that wraps a base retriever and filters the results."""
from typing import List
from typing import Any, List
from pydantic import BaseModel, Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.retrievers.document_compressors.base import (
BaseDocumentCompressor,
)
@ -24,7 +29,13 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
Args:
@ -33,14 +44,24 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
Returns:
Sequence of relevant documents
"""
docs = self.base_retriever.get_relevant_documents(query)
docs = self.base_retriever.get_relevant_documents(
query, callbacks=run_manager.get_child(), **kwargs
)
if docs:
compressed_docs = self.base_compressor.compress_documents(docs, query)
compressed_docs = self.base_compressor.compress_documents(
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:
return []
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
Args:
@ -49,10 +70,12 @@ class ContextualCompressionRetriever(BaseRetriever, BaseModel):
Returns:
List of relevant documents
"""
docs = await self.base_retriever.aget_relevant_documents(query)
docs = await self.base_retriever.aget_relevant_documents(
query, callbacks=run_manager.get_child(), **kwargs
)
if docs:
compressed_docs = await self.base_compressor.acompress_documents(
docs, query
docs, query, callbacks=run_manager.get_child()
)
return list(compressed_docs)
else:

@ -1,8 +1,12 @@
from typing import List, Optional
from typing import Any, List, Optional
import aiohttp
import requests
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
@ -23,7 +27,13 @@ class DataberryRetriever(BaseRetriever):
self.api_key = api_key
self.top_k = top_k
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
response = requests.post(
self.datastore_url,
json={
@ -48,7 +58,13 @@ class DataberryRetriever(BaseRetriever):
for r in data["results"]
]
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(
"POST",

@ -4,6 +4,10 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.utils import maximal_marginal_relevance
@ -49,7 +53,12 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
Args:
@ -201,5 +210,10 @@ class DocArrayRetriever(BaseRetriever, BaseModel):
return lc_doc
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,9 +1,11 @@
"""Interface for retrieved document compressors."""
from abc import ABC, abstractmethod
from typing import List, Sequence, Union
from inspect import signature
from typing import List, Optional, Sequence, Union
from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseDocumentTransformer, Document
@ -12,13 +14,19 @@ class BaseDocumentCompressor(BaseModel, ABC):
@abstractmethod
def compress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
@abstractmethod
async def acompress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
@ -35,12 +43,26 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
arbitrary_types_allowed = True
def compress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Transform a list of documents."""
for _transformer in self.transformers:
if isinstance(_transformer, BaseDocumentCompressor):
documents = _transformer.compress_documents(documents, query)
accepts_callbacks = (
signature(_transformer.compress_documents).parameters.get(
"callbacks"
)
is not None
)
if accepts_callbacks:
documents = _transformer.compress_documents(
documents, query, callbacks=callbacks
)
else:
documents = _transformer.compress_documents(documents, query)
elif isinstance(_transformer, BaseDocumentTransformer):
documents = _transformer.transform_documents(documents)
else:
@ -48,12 +70,26 @@ class DocumentCompressorPipeline(BaseDocumentCompressor):
return documents
async def acompress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context."""
for _transformer in self.transformers:
if isinstance(_transformer, BaseDocumentCompressor):
documents = await _transformer.acompress_documents(documents, query)
accepts_callbacks = (
signature(_transformer.acompress_documents).parameters.get(
"callbacks"
)
is not None
)
if accepts_callbacks:
documents = await _transformer.acompress_documents(
documents, query, callbacks=callbacks
)
else:
documents = await _transformer.acompress_documents(documents, query)
elif isinstance(_transformer, BaseDocumentTransformer):
documents = await _transformer.atransform_documents(documents)
else:

@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, Optional, Sequence
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.retrievers.document_compressors.chain_extract_prompt import (
prompt_template,
@ -48,25 +49,33 @@ class LLMChainExtractor(BaseDocumentCompressor):
"""Callable for constructing the chain input from the query and a Document."""
def compress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress page content of raw documents."""
compressed_docs = []
for doc in documents:
_input = self.get_input(query, doc)
output = self.llm_chain.predict_and_parse(**_input)
output = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
if len(output) == 0:
continue
compressed_docs.append(Document(page_content=output, metadata=doc.metadata))
return compressed_docs
async def acompress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Compress page content of raw documents asynchronously."""
outputs = await asyncio.gather(
*[
self.llm_chain.apredict_and_parse(**self.get_input(query, doc))
self.llm_chain.apredict_and_parse(
**self.get_input(query, doc), callbacks=callbacks
)
for doc in documents
]
)

@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, Optional, Sequence
from langchain import BasePromptTemplate, LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.retrievers.document_compressors.chain_filter_prompt import (
@ -35,19 +36,27 @@ class LLMChainFilter(BaseDocumentCompressor):
"""Callable for constructing the chain input from the query and a Document."""
def compress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs = []
for doc in documents:
_input = self.get_input(query, doc)
include_doc = self.llm_chain.predict_and_parse(**_input)
include_doc = self.llm_chain.predict_and_parse(
**_input, callbacks=callbacks
)
if include_doc:
filtered_docs.append(doc)
return filtered_docs
async def acompress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents."""
raise NotImplementedError

@ -1,9 +1,10 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, Sequence
from typing import TYPE_CHECKING, Dict, Optional, Sequence
from pydantic import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain.schema import Document
from langchain.utils import get_from_dict_or_env
@ -48,7 +49,10 @@ class CohereRerank(BaseDocumentCompressor):
return values
def compress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
if len(documents) == 0: # to avoid empty api call
return []
@ -65,6 +69,9 @@ class CohereRerank(BaseDocumentCompressor):
return final_results
async def acompress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
raise NotImplementedError

@ -4,6 +4,7 @@ from typing import Callable, Dict, Optional, Sequence
import numpy as np
from pydantic import root_validator
from langchain.callbacks.manager import Callbacks
from langchain.document_transformers import (
_get_embeddings_from_stateful_docs,
get_stateful_documents,
@ -44,7 +45,10 @@ class EmbeddingsFilter(BaseDocumentCompressor):
return values
def compress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter documents based on similarity of their embeddings to the query."""
stateful_documents = get_stateful_documents(documents)
@ -64,7 +68,10 @@ class EmbeddingsFilter(BaseDocumentCompressor):
return [stateful_documents[i] for i in included_idxs]
async def acompress_documents(
self, documents: Sequence[Document], query: str
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""Filter down documents."""
raise NotImplementedError

@ -1,9 +1,14 @@
"""Wrapper around Elasticsearch vector database."""
from __future__ import annotations
import uuid
from typing import Any, Iterable, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
@ -111,7 +116,13 @@ class ElasticSearchBM25Retriever(BaseRetriever):
self.client.indices.refresh(index=self.index_name)
return ids
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
query_dict = {"query": {"match": {"content": query}}}
res = self.client.search(index=self.index_name, body=query_dict)
@ -120,5 +131,11 @@ class ElasticSearchBM25Retriever(BaseRetriever):
docs.append(Document(page_content=r["_source"]["content"]))
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -3,6 +3,10 @@ from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
@ -257,7 +261,12 @@ class AmazonKendraRetriever(BaseRetriever):
docs = r_result.get_top_k_docs(top_k)
return docs
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Run search on Kendra index and get top k documents
Example:
@ -269,5 +278,10 @@ class AmazonKendraRetriever(BaseRetriever):
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("Async version is not implemented for Kendra yet.")

@ -10,6 +10,10 @@ from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@ -51,7 +55,13 @@ class KNNRetriever(BaseRetriever, BaseModel):
index = create_index(texts, embeddings)
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
query_embeds = np.array(self.embeddings.embed_query(query))
# calc L2 norm
index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True))
@ -73,5 +83,11 @@ class KNNRetriever(BaseRetriever, BaseModel):
]
return top_k_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,7 +1,11 @@
from typing import Any, Dict, List, cast
from typing import Any, Dict, List, Optional, cast
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
@ -11,7 +15,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
index: Any
query_kwargs: Dict = Field(default_factory=dict)
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query."""
try:
from llama_index.indices.base import BaseGPTIndex
@ -33,7 +43,13 @@ class LlamaIndexRetriever(BaseRetriever, BaseModel):
)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("LlamaIndexRetriever does not support async")
@ -43,7 +59,13 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
graph: Any
query_configs: List[Dict] = Field(default_factory=list)
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query."""
try:
from llama_index.composability.graph import (
@ -73,5 +95,11 @@ class LlamaIndexGraphRetriever(BaseRetriever, BaseModel):
)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")

@ -1,5 +1,9 @@
from typing import List
from typing import Any, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
@ -24,7 +28,12 @@ class MergerRetriever(BaseRetriever):
self.retrievers = retrievers
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""
Get the relevant documents for a given query.
@ -36,11 +45,16 @@ class MergerRetriever(BaseRetriever):
"""
# Merge the results of the retrievers.
merged_documents = self.merge_documents(query)
merged_documents = self.merge_documents(query, run_manager)
return merged_documents
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""
Asynchronously get the relevant documents for a given query.
@ -52,11 +66,13 @@ class MergerRetriever(BaseRetriever):
"""
# Merge the results of the retrievers.
merged_documents = await self.amerge_documents(query)
merged_documents = await self.amerge_documents(query, run_manager)
return merged_documents
def merge_documents(self, query: str) -> List[Document]:
def merge_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""
Merge the results of the retrievers.
@ -69,7 +85,10 @@ class MergerRetriever(BaseRetriever):
# Get the results of all retrievers.
retriever_docs = [
retriever.get_relevant_documents(query) for retriever in self.retrievers
retriever.get_relevant_documents(
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
)
for i, retriever in enumerate(self.retrievers)
]
# Merge the results of the retrievers.
@ -82,7 +101,9 @@ class MergerRetriever(BaseRetriever):
return merged_documents
async def amerge_documents(self, query: str) -> List[Document]:
async def amerge_documents(
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""
Asynchronously merge the results of the retrievers.
@ -95,8 +116,10 @@ class MergerRetriever(BaseRetriever):
# Get the results of all retrievers.
retriever_docs = [
await retriever.aget_relevant_documents(query)
for retriever in self.retrievers
await retriever.aget_relevant_documents(
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1))
)
for i, retriever in enumerate(self.retrievers)
]
# Merge the results of the retrievers.

@ -1,5 +1,9 @@
from typing import Any, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
@ -17,7 +21,13 @@ class MetalRetriever(BaseRetriever):
self.client: Metal = client
self.params = params or {}
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
results = self.client.search({"text": query}, **self.params)
final_results = []
for r in results["data"]:
@ -25,5 +35,11 @@ class MetalRetriever(BaseRetriever):
final_results.append(Document(page_content=r["text"], metadata=metadata))
return final_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -2,6 +2,10 @@
import warnings
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.milvus import Milvus
@ -39,10 +43,24 @@ class MilvusRetriever(BaseRetriever):
"""
self.store.add_texts(texts, metadatas)
def get_relevant_documents(self, query: str) -> List[Document]:
return self.retriever.get_relevant_documents(query)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return self.retriever.get_relevant_documents(
query, run_manager=run_manager.get_child(), **kwargs
)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,8 +1,12 @@
import logging
from typing import List
from typing import Any, List
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.output_parsers.pydantic import PydanticOutputParser
@ -91,7 +95,12 @@ class MultiQueryRetriever(BaseRetriever):
parser_key=parser_key,
)
def get_relevant_documents(self, question: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get relevated documents given a user query.
Args:
@ -100,15 +109,22 @@ class MultiQueryRetriever(BaseRetriever):
Returns:
Unique union of relevant documents from all generated queries
"""
queries = self.generate_queries(question)
documents = self.retrieve_documents(queries)
queries = self.generate_queries(query, run_manager)
documents = self.retrieve_documents(queries, run_manager)
unique_documents = self.unique_union(documents)
return unique_documents
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def generate_queries(self, question: str) -> List[str]:
def generate_queries(
self, question: str, run_manager: CallbackManagerForRetrieverRun
) -> List[str]:
"""Generate queries based upon user input.
Args:
@ -117,13 +133,17 @@ class MultiQueryRetriever(BaseRetriever):
Returns:
List of LLM generated queries that are similar to the user input
"""
response = self.llm_chain({"question": question})
response = self.llm_chain(
{"question": question}, callbacks=run_manager.get_child()
)
lines = getattr(response["text"], self.parser_key, [])
if self.verbose:
logger.info(f"Generated queries: {lines}")
return lines
def retrieve_documents(self, queries: List[str]) -> List[Document]:
def retrieve_documents(
self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Run all LLM generated queries.
Args:
@ -134,7 +154,9 @@ class MultiQueryRetriever(BaseRetriever):
"""
documents = []
for query in queries:
docs = self.retriever.get_relevant_documents(query)
docs = self.retriever.get_relevant_documents(
query, callbacks=run_manager.get_child()
)
documents.extend(docs)
return documents

@ -1,9 +1,14 @@
"""Taken from: https://docs.pinecone.io/docs/hybrid-search"""
import hashlib
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@ -137,7 +142,13 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
from pinecone_text.hybrid import hybrid_convex_scale
sparse_vec = self.sparse_encoder.encode_queries(query)
@ -162,5 +173,11 @@ class PineconeHybridSearchRetriever(BaseRetriever, BaseModel):
# return search results as json
return final_result
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -0,0 +1,35 @@
"""A retriever that uses PubMed API to retrieve documents."""
from typing import Any, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
from langchain.utilities.pupmed import PubMedAPIWrapper
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
    """Retriever backed by the PubMed API.

    A thin adapter over ``PubMedAPIWrapper``: the wrapper's ``load_docs``
    performs the actual lookup, and every ``PubMedAPIWrapper`` constructor
    argument is accepted unchanged.
    """

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Fetch PubMed documents for *query* via the API wrapper.

        Callback dispatch is handled by the ``BaseRetriever`` entry point,
        so ``run_manager`` is accepted but not used here.
        """
        docs: List[Document] = self.load_docs(query=query)
        return docs

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Async retrieval is not implemented for PubMed."""
        raise NotImplementedError

@ -1,18 +1,5 @@
from typing import List
from langchain.retrievers.pubmed import PubMedRetriever
from langchain.schema import BaseRetriever, Document
from langchain.utilities.pupmed import PubMedAPIWrapper
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
"""
It is effectively a wrapper for PubMedAPIWrapper.
It wraps load() to get_relevant_documents().
It uses all PubMedAPIWrapper arguments without any change.
"""
def get_relevant_documents(self, query: str) -> List[Document]:
return self.load_docs(query=query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
__all__ = [
"PubMedRetriever",
]

@ -1,9 +1,13 @@
from typing import List, Optional
from typing import Any, List, Optional
import aiohttp
import requests
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
@ -15,7 +19,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
page_content_key: str = "page_content"
metadata_key: str = "metadata"
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
response = requests.post(
self.url, json={self.input_key: query}, headers=self.headers
)
@ -27,7 +37,13 @@ class RemoteLangChainRetriever(BaseRetriever, BaseModel):
for r in result[self.response_key]
]
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
async with aiohttp.ClientSession() as session:
async with session.request(
"POST", self.url, headers=self.headers, json={self.input_key: query}

@ -1,11 +1,15 @@
"""Retriever that generates and executes structured queries over its own data source."""
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
@ -79,8 +83,12 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
)
return values
def get_relevant_documents(
self, query: str, callbacks: Callbacks = None
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Get documents relevant for a query.
@ -93,7 +101,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast(
StructuredQuery,
self.llm_chain.predict_and_parse(callbacks=callbacks, **inputs),
self.llm_chain.predict_and_parse(
callbacks=run_manager.get_child(), **inputs
),
)
if self.verbose:
print(structured_query)
@ -110,7 +120,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun],
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
@classmethod

@ -10,6 +10,10 @@ from typing import Any, List, Optional
import numpy as np
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@ -50,7 +54,13 @@ class SVMRetriever(BaseRetriever, BaseModel):
index = create_index(texts, embeddings)
return cls(embeddings=embeddings, index=index, texts=texts, **kwargs)
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
from sklearn import svm
query_embeds = np.array(self.embeddings.embed_query(query))
@ -87,5 +97,11 @@ class SVMRetriever(BaseRetriever, BaseModel):
top_k_results.append(Document(page_content=self.texts[row - 1]))
return top_k_results
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -2,12 +2,17 @@
Largely based on
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from pydantic import BaseModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
@ -58,7 +63,13 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
texts=texts, tfidf_params=tfidf_params, metadatas=metadatas, **kwargs
)
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
from sklearn.metrics.pairwise import cosine_similarity
query_vec = self.vectorizer.transform(
@ -70,5 +81,11 @@ class TFIDFRetriever(BaseRetriever, BaseModel):
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,10 +1,15 @@
"""Retriever that combines embedding similarity with recency in retrieving values."""
import datetime
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
@ -80,7 +85,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
results[buffer_idx] = (doc, relevance)
return results
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Return documents that are relevant to the query."""
current_time = datetime.datetime.now()
docs_and_scores = {
@ -103,7 +114,13 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever, BaseModel):
result.append(buffered_doc)
return result
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Return documents that are relevant to the query."""
raise NotImplementedError

@ -1,9 +1,14 @@
"""Wrapper for retrieving documents from Vespa."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
if TYPE_CHECKING:
@ -59,12 +64,24 @@ class VespaRetriever(BaseRetriever):
docs.append(Document(page_content=page_content, metadata=metadata))
return docs
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
body = self._query_body.copy()
body["query"] = query
return self._query(body)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def get_relevant_documents_with_filter(

@ -1,4 +1,5 @@
"""Wrapper around weaviate vector database."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
@ -6,6 +7,10 @@ from uuid import uuid4
from pydantic import Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
@ -82,8 +87,13 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
ids.append(_id)
return ids
def get_relevant_documents(
self, query: str, where_filter: Optional[Dict[str, object]] = None
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
where_filter: Optional[Dict[str, object]] = None,
**kwargs: Any,
) -> List[Document]:
"""Look up similar documents in Weaviate."""
query_obj = self._client.query.get(self._index_name, self._query_attrs)
@ -101,7 +111,12 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
docs.append(Document(page_content=text, metadata=res))
return docs
async def aget_relevant_documents(
self, query: str, where_filter: Optional[Dict[str, object]] = None
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
where_filter: Optional[Dict[str, object]] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,5 +1,9 @@
from typing import List
from typing import Any, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
from langchain.utilities.wikipedia import WikipediaAPIWrapper
@ -11,8 +15,20 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
It uses all WikipediaAPIWrapper arguments without any change.
"""
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return self.load(query=query)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,7 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
if TYPE_CHECKING:
@ -54,8 +58,13 @@ class ZepRetriever(BaseRetriever):
if r.message
]
def get_relevant_documents(
self, query: str, metadata: Optional[Dict] = None
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
metadata: Optional[Dict] = None,
**kwargs: Any,
) -> List[Document]:
from zep_python import MemorySearchPayload
@ -69,8 +78,13 @@ class ZepRetriever(BaseRetriever):
return self._search_result_to_doc(results)
async def aget_relevant_documents(
self, query: str, metadata: Optional[Dict] = None
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
metadata: Optional[Dict] = None,
**kwargs: Any,
) -> List[Document]:
from zep_python import MemorySearchPayload

@ -2,6 +2,10 @@
import warnings
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.zilliz import Zilliz
@ -39,10 +43,24 @@ class ZillizRetriever(BaseRetriever):
"""
self.store.add_texts(texts, metadatas)
def get_relevant_documents(self, query: str) -> List[Document]:
return self.retriever.get_relevant_documents(query)
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
return self.retriever.get_relevant_documents(
query, run_manager=run_manager.get_child(), **kwargs
)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError

@ -1,9 +1,12 @@
"""Common schema objects."""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from inspect import signature
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
@ -20,6 +23,13 @@ from pydantic import BaseModel, Field, root_validator
from langchain.load.serializable import Serializable
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
Callbacks,
)
RUN_KEY = "__run"
@ -360,29 +370,150 @@ class Document(Serializable):
class BaseRetriever(ABC):
"""Base interface for retrievers."""
"""Base interface for a retriever."""
_new_arg_supported: bool = False
_expects_other_args: bool = False
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
# Version upgrade for old retrievers that implemented the public
# methods directly.
if cls.get_relevant_documents != BaseRetriever.get_relevant_documents:
warnings.warn(
"Retrievers must implement abstract `_get_relevant_documents` method"
" instead of `get_relevant_documents`",
DeprecationWarning,
)
swap = cls.get_relevant_documents
cls.get_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.get_relevant_documents
)
cls._get_relevant_documents = swap # type: ignore[assignment]
if (
hasattr(cls, "aget_relevant_documents")
and cls.aget_relevant_documents != BaseRetriever.aget_relevant_documents
):
warnings.warn(
"Retrievers must implement abstract `_aget_relevant_documents` method"
" instead of `aget_relevant_documents`",
DeprecationWarning,
)
aswap = cls.aget_relevant_documents
cls.aget_relevant_documents = ( # type: ignore[assignment]
BaseRetriever.aget_relevant_documents
)
cls._aget_relevant_documents = aswap # type: ignore[assignment]
parameters = signature(cls._get_relevant_documents).parameters
cls._new_arg_supported = parameters.get("run_manager") is not None
# If a V1 retriever broke the interface and expects additional arguments
cls._expects_other_args = (not cls._new_arg_supported) and len(parameters) > 2
@abstractmethod
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Get documents relevant to a query.
Args:
query: string to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
@abstractmethod
async def aget_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
run_manager: The callbacks handler to use
Returns:
List of relevant documents
"""
def get_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = callback_manager.on_retriever_start(
query,
**kwargs,
)
try:
if self._new_arg_supported:
result = self._get_relevant_documents(
query, run_manager=run_manager, **kwargs
)
elif self._expects_other_args:
result = self._get_relevant_documents(query, **kwargs)
else:
result = self._get_relevant_documents(query) # type: ignore[call-arg]
except Exception as e:
run_manager.on_retriever_error(e)
raise e
else:
run_manager.on_retriever_end(
result,
**kwargs,
)
return result
async def aget_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
)
run_manager = await callback_manager.on_retriever_start(
query,
**kwargs,
)
try:
if self._new_arg_supported:
result = await self._aget_relevant_documents(
query, run_manager=run_manager, **kwargs
)
elif self._expects_other_args:
result = await self._aget_relevant_documents(query, **kwargs)
else:
result = await self._aget_relevant_documents(
query, # type: ignore[call-arg]
)
except Exception as e:
await run_manager.on_retriever_error(e)
raise e
else:
await run_manager.on_retriever_end(
result,
**kwargs,
)
return result
# For backwards compatibility

@ -20,6 +20,10 @@ from typing import (
import numpy as np
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
@ -490,7 +494,12 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.vector_search(query, k=self.k)
elif self.search_type == "hybrid":
@ -501,7 +510,12 @@ class AzureSearchVectorStoreRetriever(BaseRetriever, BaseModel):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError(
"AzureSearchVectorStoreRetriever does not support async"
)

@ -1,4 +1,5 @@
"""Interface for vector stores."""
from __future__ import annotations
import asyncio
@ -20,6 +21,10 @@ from typing import (
from pydantic import BaseModel, Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
@ -402,7 +407,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
@ -420,7 +431,13 @@ class VectorStoreRetriever(BaseRetriever, BaseModel):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs

@ -1,4 +1,5 @@
"""Wrapper around Redis vector database."""
from __future__ import annotations
import json
@ -21,6 +22,10 @@ from typing import (
import numpy as np
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
@ -614,7 +619,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
raise ValueError(f"search_type of {search_type} not allowed.")
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
elif self.search_type == "similarity_limit":
@ -625,7 +636,13 @@ class RedisVectorStoreRetriever(VectorStoreRetriever, BaseModel):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError("RedisVectorStoreRetriever does not support async")
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:

@ -1,21 +1,17 @@
"""Wrapper around SingleStore DB."""
from __future__ import annotations
import enum
import json
from typing import (
Any,
ClassVar,
Collection,
Iterable,
List,
Optional,
Tuple,
Type,
)
from typing import Any, ClassVar, Collection, Iterable, List, Optional, Tuple, Type
from sqlalchemy.pool import QueuePool
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
@ -454,14 +450,26 @@ class SingleStoreDBRetriever(VectorStoreRetriever):
k: int = 4
allowed_search_types: ClassVar[Collection[str]] = ("similarity",)
def get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents(
self,
query: str,
*,
run_manager: Optional[CallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, k=self.k)
else:
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError(
"SingleStoreDBVectorStoreRetriever does not support async"
)

544
poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -19,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel):
ignore_llm_: bool = False
ignore_chain_: bool = False
ignore_agent_: bool = False
ignore_retriever_: bool = False
ignore_chat_model_: bool = False
# add finer-grained counters for easier debugging of failing tests
@ -32,6 +33,9 @@ class BaseFakeCallbackHandler(BaseModel):
agent_actions: int = 0
agent_ends: int = 0
chat_model_starts: int = 0
retriever_starts: int = 0
retriever_ends: int = 0
retriever_errors: int = 0
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@ -52,7 +56,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.llm_streams += 1
def on_chain_start_common(self) -> None:
    """Record a chain-start event by bumping the shared counters."""
    # An earlier revision removed a debug print here but left the bare
    # no-op expression ("CHAIN START") behind; it is dropped entirely.
    self.chain_starts += 1
    self.starts += 1
@ -91,6 +95,18 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
def on_text_common(self) -> None:
self.text += 1
def on_retriever_start_common(self) -> None:
    """Count one retriever-start event."""
    self.retriever_starts += 1
    self.starts += 1
def on_retriever_end_common(self) -> None:
    """Count one retriever-end event."""
    self.retriever_ends += 1
    self.ends += 1
def on_retriever_error_common(self) -> None:
    """Count one retriever-error event."""
    self.retriever_errors += 1
    self.errors += 1
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Fake callback handler for testing."""
@ -110,6 +126,11 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
    """Whether retriever callbacks are ignored by this handler."""
    should_ignore = self.ignore_retriever_
    return should_ignore
def on_llm_start(
self,
*args: Any,
@ -201,6 +222,27 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_text_common()
def on_retriever_start(self, *args: Any, **kwargs: Any) -> Any:
    """Forward retriever-start callbacks to the shared counter helper."""
    self.on_retriever_start_common()
def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any:
    """Forward retriever-end callbacks to the shared counter helper."""
    self.on_retriever_end_common()
def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
    """Forward retriever-error callbacks to the shared counter helper."""
    self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
    # Return the same instance rather than a copy: any framework code that
    # deep-copies its handlers still shares this object's counters, so
    # tests can observe every event on the original handler.
    return self

@ -141,6 +141,22 @@ def test_ignore_agent() -> None:
assert handler2.errors == 1
def test_ignore_retriever() -> None:
    """A handler with ignore_retriever_ set must see no retriever events."""
    ignoring = FakeCallbackHandler(ignore_retriever_=True)
    listening = FakeCallbackHandler()
    manager = CallbackManager(handlers=[ignoring, listening])
    run_manager = manager.on_retriever_start("")
    run_manager.on_retriever_end([])
    run_manager.on_retriever_error(Exception())
    # The ignoring handler recorded nothing.
    assert (ignoring.starts, ignoring.ends, ignoring.errors) == (0, 0, 0)
    # The listening handler saw exactly one of each event.
    assert (listening.starts, listening.ends, listening.errors) == (1, 1, 1)
@pytest.mark.asyncio
async def test_async_callback_manager() -> None:
"""Test the AsyncCallbackManager."""

@ -0,0 +1,220 @@
"""Test Base Retriever logic."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import pytest
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.fixture
def fake_retriever_v1() -> BaseRetriever:
    """Build a legacy ("v1") retriever that overrides the public methods.

    Defining the subclass itself must emit a DeprecationWarning, because
    ``BaseRetriever.__init_subclass__`` detects old-style overrides of
    ``get_relevant_documents`` / ``aget_relevant_documents`` and rewires
    them onto the private ``_get_relevant_documents`` slots.  Hence the
    class body lives inside the ``pytest.warns`` block.
    """
    with pytest.warns(
        DeprecationWarning,
        match="Retrievers must implement abstract "
        "`_get_relevant_documents` method instead of `get_relevant_documents`",
    ):

        class FakeRetrieverV1(BaseRetriever):
            def get_relevant_documents(  # type: ignore[override]
                self,
                query: str,
            ) -> List[Document]:
                # Upgrade logic rebinds this as an unbound function; check
                # that it is still invoked as a bound method.
                assert isinstance(self, FakeRetrieverV1)
                return [
                    Document(page_content=query, metadata={"uuid": "1234"}),
                ]

            async def aget_relevant_documents(  # type: ignore[override]
                self,
                query: str,
            ) -> List[Document]:
                assert isinstance(self, FakeRetrieverV1)
                return [
                    Document(
                        page_content=f"Async query {query}", metadata={"uuid": "1234"}
                    ),
                ]

    return FakeRetrieverV1()  # type: ignore[abstract]
def test_fake_retriever_v1_upgrade(fake_retriever_v1: BaseRetriever) -> None:
    """An upgraded v1 retriever runs with callbacks despite its old API."""
    handler = FakeCallbackHandler()
    assert fake_retriever_v1._new_arg_supported is False
    assert fake_retriever_v1._expects_other_args is False
    docs: List[Document] = fake_retriever_v1.get_relevant_documents(
        "Foo", callbacks=[handler]
    )
    assert docs[0].page_content == "Foo"
    # One start, one end, no errors.
    assert (
        handler.retriever_starts,
        handler.retriever_ends,
        handler.retriever_errors,
    ) == (1, 1, 0)
@pytest.mark.asyncio
async def test_fake_retriever_v1_upgrade_async(
    fake_retriever_v1: BaseRetriever,
) -> None:
    """The async path of an upgraded v1 retriever also fires callbacks."""
    handler = FakeCallbackHandler()
    assert fake_retriever_v1._new_arg_supported is False
    assert fake_retriever_v1._expects_other_args is False
    docs: List[Document] = await fake_retriever_v1.aget_relevant_documents(
        "Foo", callbacks=[handler]
    )
    assert docs[0].page_content == "Async query Foo"
    # One start, one end, no errors.
    assert (
        handler.retriever_starts,
        handler.retriever_ends,
        handler.retriever_errors,
    ) == (1, 1, 0)
@pytest.fixture
def fake_retriever_v1_with_kwargs() -> BaseRetriever:
    """Build a legacy retriever whose overrides take extra keyword args.

    Mirrors real v1 retrievers (e.g. the Weaviate retriever) that broke the
    original interface by accepting additional parameters; the upgrade path
    must set ``_expects_other_args`` and forward ``**kwargs`` through.
    """
    # Test for things like the Weaviate V1 Retriever.
    with pytest.warns(
        DeprecationWarning,
        match="Retrievers must implement abstract "
        "`_get_relevant_documents` method instead of `get_relevant_documents`",
    ):

        class FakeRetrieverV1(BaseRetriever):
            def get_relevant_documents(  # type: ignore[override]
                self, query: str, where_filter: Optional[Dict[str, object]] = None
            ) -> List[Document]:
                # Still invoked as a bound method after the upgrade rebind.
                assert isinstance(self, FakeRetrieverV1)
                return [
                    Document(page_content=query, metadata=where_filter or {}),
                ]

            async def aget_relevant_documents(  # type: ignore[override]
                self, query: str, where_filter: Optional[Dict[str, object]] = None
            ) -> List[Document]:
                assert isinstance(self, FakeRetrieverV1)
                return [
                    Document(
                        page_content=f"Async query {query}", metadata=where_filter or {}
                    ),
                ]

    return FakeRetrieverV1()  # type: ignore[abstract]
def test_fake_retriever_v1_with_kwargs_upgrade(
    fake_retriever_v1_with_kwargs: BaseRetriever,
) -> None:
    """Extra keyword args are forwarded to an upgraded v1 retriever."""
    handler = FakeCallbackHandler()
    assert fake_retriever_v1_with_kwargs._new_arg_supported is False
    assert fake_retriever_v1_with_kwargs._expects_other_args is True
    docs: List[Document] = fake_retriever_v1_with_kwargs.get_relevant_documents(
        "Foo", callbacks=[handler], where_filter={"foo": "bar"}
    )
    assert docs[0].page_content == "Foo"
    assert docs[0].metadata == {"foo": "bar"}
    # One start, one end, no errors.
    assert (
        handler.retriever_starts,
        handler.retriever_ends,
        handler.retriever_errors,
    ) == (1, 1, 0)
@pytest.mark.asyncio
async def test_fake_retriever_v1_with_kwargs_upgrade_async(
    fake_retriever_v1_with_kwargs: BaseRetriever,
) -> None:
    """Upgraded kwarg-style v1 retrievers forward extra kwargs asynchronously."""
    handler = FakeCallbackHandler()
    # Legacy signature: no run-manager support, but extra kwargs expected.
    assert not fake_retriever_v1_with_kwargs._new_arg_supported
    assert fake_retriever_v1_with_kwargs._expects_other_args
    docs: List[Document] = await fake_retriever_v1_with_kwargs.aget_relevant_documents(
        "Foo", callbacks=[handler], where_filter={"foo": "bar"}
    )
    first = docs[0]
    assert first.page_content == "Async query Foo"
    # The extra kwarg should have been threaded through to the metadata.
    assert first.metadata == {"foo": "bar"}
    assert handler.retriever_starts == 1
    assert handler.retriever_ends == 1
    assert handler.retriever_errors == 0
@pytest.fixture
def fake_retriever_v2() -> BaseRetriever:
    """Fixture implementing the new `_get_relevant_documents` interface.

    Fix: the run_manager annotations used PEP 604 ``X | None`` syntax, which
    raises ``TypeError`` at class-creation time on Python < 3.10 (there is no
    ``from __future__ import annotations`` in this module). Use ``Optional[...]``
    for consistency with the rest of the file.
    """

    class FakeRetrieverV2(BaseRetriever):
        def _get_relevant_documents(
            self,
            query: str,
            *,
            run_manager: Optional[CallbackManagerForRetrieverRun] = None,
            **kwargs: Any,
        ) -> List[Document]:
            assert isinstance(self, FakeRetrieverV2)
            # The upgraded base class must always supply a run manager.
            assert run_manager is not None
            assert isinstance(run_manager, CallbackManagerForRetrieverRun)
            if "throw_error" in kwargs:
                raise ValueError("Test error")
            return [
                Document(page_content=query, metadata=kwargs),
            ]

        async def _aget_relevant_documents(
            self,
            query: str,
            *,
            run_manager: Optional[AsyncCallbackManagerForRetrieverRun] = None,
            **kwargs: Any,
        ) -> List[Document]:
            assert isinstance(self, FakeRetrieverV2)
            # The upgraded base class must always supply an async run manager.
            assert run_manager is not None
            assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
            if "throw_error" in kwargs:
                raise ValueError("Test error")
            return [
                Document(page_content=f"Async query {query}", metadata=kwargs),
            ]

    return FakeRetrieverV2()  # type: ignore[abstract]
def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
    """Sync calls on a v2 retriever fire callbacks, pass kwargs, report errors."""
    handler = FakeCallbackHandler()
    assert fake_retriever_v2._new_arg_supported is True
    docs = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[handler])
    assert docs[0].page_content == "Foo"
    assert handler.retriever_starts == 1
    assert handler.retriever_ends == 1
    assert handler.retriever_errors == 0
    # Extra keyword arguments surface in the returned document metadata.
    docs_with_kwargs = fake_retriever_v2.get_relevant_documents(
        "Foo", callbacks=[handler], foo="bar"
    )
    assert docs_with_kwargs[0].metadata == {"foo": "bar"}
    # A raised error propagates to the caller and is reported via callbacks.
    with pytest.raises(ValueError, match="Test error"):
        fake_retriever_v2.get_relevant_documents(
            "Foo", callbacks=[handler], throw_error=True
        )
    assert handler.retriever_errors == 1
@pytest.mark.asyncio
async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None:
    """Async calls on a v2 retriever fire callbacks, pass kwargs, report errors."""
    callbacks = FakeCallbackHandler()
    assert fake_retriever_v2._new_arg_supported is True
    results = await fake_retriever_v2.aget_relevant_documents(
        "Foo", callbacks=[callbacks]
    )
    assert results[0].page_content == "Async query Foo"
    assert callbacks.retriever_starts == 1
    assert callbacks.retriever_ends == 1
    assert callbacks.retriever_errors == 0
    # Extra keyword arguments surface in the returned document metadata.
    results2 = await fake_retriever_v2.aget_relevant_documents(
        "Foo", callbacks=[callbacks], foo="bar"
    )
    assert results2[0].metadata == {"foo": "bar"}
    # A raised error propagates to the caller.
    # NOTE(review): unlike the sync test, there is no trailing
    # `retriever_errors == 1` assertion visible here — confirm whether it was
    # truncated and should be restored after this `with` block.
    with pytest.raises(ValueError, match="Test error"):
        await fake_retriever_v2.aget_relevant_documents(
            "Foo", callbacks=[callbacks], throw_error=True
        )