mirror of https://github.com/hwchase17/langchain
Add New Retriever Interface with Callbacks (#5962)
Handle the new retriever events in a way that (I think) is entirely backwards compatible? Needs more testing for some of the chain changes and all. This creates an entire new run type, however. We could also just treat this as an event within a chain run presumably (same with memory) Adds a subclass initializer that upgrades old retriever implementations to the new schema, along with tests to ensure they work. First commit doesn't upgrade any of our retriever implementations (to show that we can pass the tests along with additional ones testing the upgrade logic). Second commit upgrades the known universe of retrievers in langchain. - [X] Add callback handling methods for retriever start/end/error (open to renaming to 'retrieval' if you want that) - [X] Update BaseRetriever schema to support callbacks - [X] Tests for upgrading old "v1" retrievers for backwards compatibility - [X] Update existing retriever implementations to implement the new interface - [X] Update calls within chains to .{a]get_relevant_documents to pass the child callback manager - [X] Update the notebooks/docs to reflect the new interface - [X] Test notebooks thoroughly Not handled: - Memory pass throughs: retrieval memory doesn't have a parent callback manager passed through the method --------- Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>pull/7005/head
parent
a5b206caf3
commit
b0859c9b18
@ -0,0 +1,162 @@
|
||||
# Implement a Custom Retriever
|
||||
|
||||
In this walkthrough, you will implement a simple custom retriever in LangChain using a simple dot product distance lookup.
|
||||
|
||||
All retrievers inherit from the `BaseRetriever` class and override the following abstract methods:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from langchain.schema import Document
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
@abstractmethod
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
```
|
||||
|
||||
|
||||
The `_get_relevant_documents` and async `_get_relevant_documents` methods can be implemented however you see fit. The `run_manager` is useful if your retriever calls other traceable LangChain primitives like LLMs, chains, or tools.
|
||||
|
||||
|
||||
Below, implement an example that fetches the most similar documents from a list of documents using a numpy array of embeddings.
|
||||
|
||||
|
||||
```python
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class NumpyRetriever(BaseRetriever):
|
||||
"""Retrieves documents from a numpy array."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
vectors: np.ndarray,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
num_to_return: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embeddings = embeddings or OpenAIEmbeddings()
|
||||
self.texts = texts
|
||||
self.vectors = vectors
|
||||
self.num_to_return = num_to_return
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
**kwargs: Any,
|
||||
) -> "NumpyRetriever":
|
||||
embeddings = embeddings or OpenAIEmbeddings()
|
||||
vectors = np.array(embeddings.embed_documents(texts))
|
||||
return cls(texts, vectors, embeddings)
|
||||
|
||||
def _get_relevant_documents_from_query_vector(
|
||||
self, vector_query: np.ndarray
|
||||
) -> List[Document]:
|
||||
dot_product = np.dot(self.vectors, vector_query)
|
||||
# Get the indices of the min 5 documents
|
||||
indices = np.argpartition(
|
||||
dot_product, -min(self.num_to_return, len(self.vectors))
|
||||
)[-self.num_to_return :]
|
||||
# Sort indices by distance
|
||||
indices = indices[np.argsort(dot_product[indices])]
|
||||
return [
|
||||
Document(
|
||||
page_content=self.texts[idx],
|
||||
metadata={"index": idx},
|
||||
)
|
||||
for idx in indices
|
||||
]
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
vector_query = np.array(self.embeddings.embed_query(query))
|
||||
return self._get_relevant_documents_from_query_vector(vector_query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
run_manager: The callbacks handler to use
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
query_emb = await self.embeddings.aembed_query(query)
|
||||
return self._get_relevant_documents_from_query_vector(np.array(query_emb))
|
||||
```
|
||||
|
||||
The retriever can be instantiated through the class method `from_texts`. It embeds the texts and stores them in a numpy array. To look up documents, it embeds the query and finds the most similar documents using a simple dot product distance.
|
||||
Once the retriever is implemented, you can use it like any other retriever in LangChain.
|
||||
|
||||
|
||||
```python
|
||||
retriever = NumpyRetriever.from_texts(texts= ["hello world", "goodbye world"])
|
||||
```
|
||||
|
||||
You can then use the retriever to get relevant documents.
|
||||
|
||||
```python
|
||||
retriever.get_relevant_documents("Hi there!")
|
||||
|
||||
# [Document(page_content='hello world', metadata={'index': 0})]
|
||||
```
|
||||
|
||||
```python
|
||||
retriever.get_relevant_documents("Bye!")
|
||||
# [Document(page_content='goodbye world', metadata={'index': 1})]
|
||||
```
|
@ -0,0 +1,35 @@
|
||||
"""A retriever that uses PubMed API to retrieve documents."""
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.pupmed import PubMedAPIWrapper
|
||||
|
||||
|
||||
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
||||
"""
|
||||
It is effectively a wrapper for PubMedAPIWrapper.
|
||||
It wraps load() to get_relevant_documents().
|
||||
It uses all PubMedAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
return self.load_docs(query=query)
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
raise NotImplementedError
|
@ -1,18 +1,5 @@
|
||||
from typing import List
|
||||
from langchain.retrievers.pubmed import PubMedRetriever
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.utilities.pupmed import PubMedAPIWrapper
|
||||
|
||||
|
||||
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
||||
"""
|
||||
It is effectively a wrapper for PubMedAPIWrapper.
|
||||
It wraps load() to get_relevant_documents().
|
||||
It uses all PubMedAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
return self.load_docs(query=query)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
raise NotImplementedError
|
||||
__all__ = [
|
||||
"PubMedRetriever",
|
||||
]
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,220 @@
|
||||
"""Test Base Retriever logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v1() -> BaseRetriever:
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Retrievers must implement abstract "
|
||||
"`_get_relevant_documents` method instead of `get_relevant_documents`",
|
||||
):
|
||||
|
||||
class FakeRetrieverV1(BaseRetriever):
|
||||
def get_relevant_documents( # type: ignore[override]
|
||||
self,
|
||||
query: str,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(page_content=query, metadata={"uuid": "1234"}),
|
||||
]
|
||||
|
||||
async def aget_relevant_documents( # type: ignore[override]
|
||||
self,
|
||||
query: str,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Async query {query}", metadata={"uuid": "1234"}
|
||||
),
|
||||
]
|
||||
|
||||
return FakeRetrieverV1() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v1_upgrade(fake_retriever_v1: BaseRetriever) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1._new_arg_supported is False
|
||||
assert fake_retriever_v1._expects_other_args is False
|
||||
results: List[Document] = fake_retriever_v1.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks]
|
||||
)
|
||||
assert results[0].page_content == "Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_retriever_v1_upgrade_async(
|
||||
fake_retriever_v1: BaseRetriever,
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1._new_arg_supported is False
|
||||
assert fake_retriever_v1._expects_other_args is False
|
||||
results: List[Document] = await fake_retriever_v1.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks]
|
||||
)
|
||||
assert results[0].page_content == "Async query Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v1_with_kwargs() -> BaseRetriever:
|
||||
# Test for things like the Weaviate V1 Retriever.
|
||||
with pytest.warns(
|
||||
DeprecationWarning,
|
||||
match="Retrievers must implement abstract "
|
||||
"`_get_relevant_documents` method instead of `get_relevant_documents`",
|
||||
):
|
||||
|
||||
class FakeRetrieverV1(BaseRetriever):
|
||||
def get_relevant_documents( # type: ignore[override]
|
||||
self, query: str, where_filter: Optional[Dict[str, object]] = None
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(page_content=query, metadata=where_filter or {}),
|
||||
]
|
||||
|
||||
async def aget_relevant_documents( # type: ignore[override]
|
||||
self, query: str, where_filter: Optional[Dict[str, object]] = None
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV1)
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Async query {query}", metadata=where_filter or {}
|
||||
),
|
||||
]
|
||||
|
||||
return FakeRetrieverV1() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v1_with_kwargs_upgrade(
|
||||
fake_retriever_v1_with_kwargs: BaseRetriever,
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1_with_kwargs._new_arg_supported is False
|
||||
assert fake_retriever_v1_with_kwargs._expects_other_args is True
|
||||
results: List[Document] = fake_retriever_v1_with_kwargs.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], where_filter={"foo": "bar"}
|
||||
)
|
||||
assert results[0].page_content == "Foo"
|
||||
assert results[0].metadata == {"foo": "bar"}
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_retriever_v1_with_kwargs_upgrade_async(
|
||||
fake_retriever_v1_with_kwargs: BaseRetriever,
|
||||
) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v1_with_kwargs._new_arg_supported is False
|
||||
assert fake_retriever_v1_with_kwargs._expects_other_args is True
|
||||
results: List[
|
||||
Document
|
||||
] = await fake_retriever_v1_with_kwargs.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], where_filter={"foo": "bar"}
|
||||
)
|
||||
assert results[0].page_content == "Async query Foo"
|
||||
assert results[0].metadata == {"foo": "bar"}
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_retriever_v2() -> BaseRetriever:
|
||||
class FakeRetrieverV2(BaseRetriever):
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV2)
|
||||
assert run_manager is not None
|
||||
assert isinstance(run_manager, CallbackManagerForRetrieverRun)
|
||||
if "throw_error" in kwargs:
|
||||
raise ValueError("Test error")
|
||||
return [
|
||||
Document(page_content=query, metadata=kwargs),
|
||||
]
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
assert isinstance(self, FakeRetrieverV2)
|
||||
assert run_manager is not None
|
||||
assert isinstance(run_manager, AsyncCallbackManagerForRetrieverRun)
|
||||
if "throw_error" in kwargs:
|
||||
raise ValueError("Test error")
|
||||
return [
|
||||
Document(page_content=f"Async query {query}", metadata=kwargs),
|
||||
]
|
||||
|
||||
return FakeRetrieverV2() # type: ignore[abstract]
|
||||
|
||||
|
||||
def test_fake_retriever_v2(fake_retriever_v2: BaseRetriever) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v2._new_arg_supported is True
|
||||
results = fake_retriever_v2.get_relevant_documents("Foo", callbacks=[callbacks])
|
||||
assert results[0].page_content == "Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
results2 = fake_retriever_v2.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], foo="bar"
|
||||
)
|
||||
assert results2[0].metadata == {"foo": "bar"}
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
fake_retriever_v2.get_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], throw_error=True
|
||||
)
|
||||
assert callbacks.retriever_errors == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fake_retriever_v2_async(fake_retriever_v2: BaseRetriever) -> None:
|
||||
callbacks = FakeCallbackHandler()
|
||||
assert fake_retriever_v2._new_arg_supported is True
|
||||
results = await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks]
|
||||
)
|
||||
assert results[0].page_content == "Async query Foo"
|
||||
assert callbacks.retriever_starts == 1
|
||||
assert callbacks.retriever_ends == 1
|
||||
assert callbacks.retriever_errors == 0
|
||||
results2 = await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], foo="bar"
|
||||
)
|
||||
assert results2[0].metadata == {"foo": "bar"}
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await fake_retriever_v2.aget_relevant_documents(
|
||||
"Foo", callbacks=[callbacks], throw_error=True
|
||||
)
|
Loading…
Reference in New Issue