mirror of https://github.com/hwchase17/langchain
Zep Retriever - Vector Search Over Chat History (#4533)
# Zep Retriever - Vector Search Over Chat History with the Zep Long-term Memory Service More on Zep: https://github.com/getzep/zep Note: This PR is related to and relies on https://github.com/hwchase17/langchain/pull/4834. I did not want to modify the `pyproject.toml` file to add the `zep-python` dependency a second time. Co-authored-by: Daniel Chalef <daniel.chalef@private.org>pull/2675/head^2
parent
5525b704cc
commit
c8c2276ccb
@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_python import SearchResult
|
||||
|
||||
|
||||
class ZepRetriever(BaseRetriever):
|
||||
"""A Retriever implementation for the Zep long-term memory store. Search your
|
||||
user's long-term chat history with Zep.
|
||||
|
||||
Note: You will need to provide the user's `session_id` to use this retriever.
|
||||
|
||||
More on Zep:
|
||||
Zep provides long-term conversation storage for LLM apps. The server stores,
|
||||
summarizes, embeds, indexes, and enriches conversational AI chat
|
||||
histories, and exposes them via simple, low-latency APIs.
|
||||
|
||||
For server installation instructions, see:
|
||||
https://getzep.github.io/deployment/quickstart/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
url: str,
|
||||
top_k: Optional[int] = None,
|
||||
):
|
||||
try:
|
||||
from zep_python import ZepClient
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import zep-python package. "
|
||||
"Please install it with `pip install zep-python`."
|
||||
)
|
||||
|
||||
self.zep_client = ZepClient(base_url=url)
|
||||
self.session_id = session_id
|
||||
self.top_k = top_k
|
||||
|
||||
def _search_result_to_doc(self, results: List[SearchResult]) -> List[Document]:
|
||||
return [
|
||||
Document(
|
||||
page_content=r.message.pop("content"),
|
||||
metadata={"score": r.dist, **r.message},
|
||||
)
|
||||
for r in results
|
||||
if r.message
|
||||
]
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
from zep_python import SearchPayload
|
||||
|
||||
payload: SearchPayload = SearchPayload(text=query)
|
||||
|
||||
results: List[SearchResult] = self.zep_client.search_memory(
|
||||
self.session_id, payload, limit=self.top_k
|
||||
)
|
||||
|
||||
return self._search_result_to_doc(results)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
from zep_python import SearchPayload
|
||||
|
||||
payload: SearchPayload = SearchPayload(text=query)
|
||||
|
||||
results: List[SearchResult] = await self.zep_client.asearch_memory(
|
||||
self.session_id, payload, limit=self.top_k
|
||||
)
|
||||
|
||||
return self._search_result_to_doc(results)
|
@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.retrievers import ZepRetriever
|
||||
from langchain.schema import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_python import SearchResult, ZepClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_results() -> List[SearchResult]:
|
||||
from zep_python import Message, SearchResult
|
||||
|
||||
search_result = [
|
||||
{
|
||||
"message": {
|
||||
"uuid": "66830914-19f5-490b-8677-1ba06bcd556b",
|
||||
"created_at": "2023-05-18T20:40:42.743773Z",
|
||||
"role": "user",
|
||||
"content": "I'm looking to plan a trip to Iceland. Can you help me?",
|
||||
"token_count": 17,
|
||||
},
|
||||
"summary": None,
|
||||
"dist": 0.8734284910450115,
|
||||
},
|
||||
{
|
||||
"message": {
|
||||
"uuid": "015e618c-ba9d-45b6-95c3-77a8e611570b",
|
||||
"created_at": "2023-05-18T20:40:42.743773Z",
|
||||
"role": "user",
|
||||
"content": "How much does a trip to Iceland typically cost?",
|
||||
"token_count": 12,
|
||||
},
|
||||
"summary": None,
|
||||
"dist": 0.8554048017463456,
|
||||
},
|
||||
]
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
message=Message.parse_obj(result["message"]),
|
||||
summary=result["summary"],
|
||||
dist=result["dist"],
|
||||
)
|
||||
for result in search_result
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.requires("zep_python")
|
||||
def zep_retriever(
|
||||
mocker: MockerFixture, search_results: List[SearchResult]
|
||||
) -> ZepRetriever:
|
||||
mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True)
|
||||
mock_zep_client.search_memory.return_value = copy.deepcopy( # type: ignore
|
||||
search_results
|
||||
)
|
||||
mock_zep_client.asearch_memory.return_value = copy.deepcopy( # type: ignore
|
||||
search_results
|
||||
)
|
||||
zep = ZepRetriever(session_id="123", url="http://localhost:8000")
|
||||
zep.zep_client = mock_zep_client
|
||||
return zep
|
||||
|
||||
|
||||
@pytest.mark.requires("zep_python")
|
||||
def test_zep_retriever_get_relevant_documents(
|
||||
zep_retriever: ZepRetriever, search_results: List[SearchResult]
|
||||
) -> None:
|
||||
documents: List[Document] = zep_retriever.get_relevant_documents(
|
||||
query="My trip to Iceland"
|
||||
)
|
||||
_test_documents(documents, search_results)
|
||||
|
||||
|
||||
@pytest.mark.requires("zep_python")
|
||||
@pytest.mark.asyncio
|
||||
async def test_zep_retriever_aget_relevant_documents(
|
||||
zep_retriever: ZepRetriever, search_results: List[SearchResult]
|
||||
) -> None:
|
||||
documents: List[Document] = await zep_retriever.aget_relevant_documents(
|
||||
query="My trip to Iceland"
|
||||
)
|
||||
_test_documents(documents, search_results)
|
||||
|
||||
|
||||
def _test_documents(
|
||||
documents: List[Document], search_results: List[SearchResult]
|
||||
) -> None:
|
||||
assert len(documents) == 2
|
||||
for i, document in enumerate(documents):
|
||||
assert document.page_content == search_results[i].message.get( # type: ignore
|
||||
"content"
|
||||
)
|
||||
assert document.metadata.get("uuid") == search_results[
|
||||
i
|
||||
].message.get( # type: ignore
|
||||
"uuid"
|
||||
)
|
||||
assert document.metadata.get("role") == search_results[
|
||||
i
|
||||
].message.get( # type: ignore
|
||||
"role"
|
||||
)
|
||||
assert document.metadata.get("score") == search_results[i].dist
|
Loading…
Reference in New Issue