Zep Hybrid Search (#5742)

Zep now supports persisting custom metadata with messages, as well as
hybrid search across both message embeddings and structured metadata.
This PR adds custom metadata handling and related enhancements to the
`ZepChatMessageHistory` and `ZepRetriever` classes so that they support
these features.

Tag maintainers/contributors who might be interested:

  VectorStores / Retrievers / Memory
  - @dev2049

---------

Co-authored-by: Daniel Chalef <daniel.chalef@private.org>
searx_updates
Daniel Chalef 12 months ago committed by GitHub
parent a0ea6f6b6b
commit 0551bc90a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.schema import (
AIMessage,
@ -11,7 +11,7 @@ from langchain.schema import (
)
if TYPE_CHECKING:
from zep_python import Memory, Message, NotFoundError, SearchResult
from zep_python import Memory, MemorySearchResult, Message, NotFoundError
logger = logging.getLogger(__name__)
@ -130,11 +130,15 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
self.zep_client.add_memory(self.session_id, zep_memory)
def search(self, query: str, limit: Optional[int] = None) -> List[SearchResult]:
def search(
self, query: str, metadata: Optional[Dict] = None, limit: Optional[int] = None
) -> List[MemorySearchResult]:
"""Search Zep memory for messages matching the query"""
from zep_python import SearchPayload
from zep_python import MemorySearchPayload
payload: SearchPayload = SearchPayload(text=query)
payload: MemorySearchPayload = MemorySearchPayload(
text=query, metadata=metadata
)
return self.zep_client.search_memory(self.session_id, payload, limit=limit)

@ -1,11 +1,11 @@
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.schema import BaseRetriever, Document
if TYPE_CHECKING:
from zep_python import SearchResult
from zep_python import MemorySearchResult
class ZepRetriever(BaseRetriever):
@ -41,7 +41,9 @@ class ZepRetriever(BaseRetriever):
self.session_id = session_id
self.top_k = top_k
def _search_result_to_doc(self, results: List[SearchResult]) -> List[Document]:
def _search_result_to_doc(
self, results: List[MemorySearchResult]
) -> List[Document]:
return [
Document(
page_content=r.message.pop("content"),
@ -51,23 +53,31 @@ class ZepRetriever(BaseRetriever):
if r.message
]
def get_relevant_documents(self, query: str) -> List[Document]:
from zep_python import SearchPayload
def get_relevant_documents(
self, query: str, metadata: Optional[Dict] = None
) -> List[Document]:
from zep_python import MemorySearchPayload
payload: SearchPayload = SearchPayload(text=query)
payload: MemorySearchPayload = MemorySearchPayload(
text=query, metadata=metadata
)
results: List[SearchResult] = self.zep_client.search_memory(
results: List[MemorySearchResult] = self.zep_client.search_memory(
self.session_id, payload, limit=self.top_k
)
return self._search_result_to_doc(results)
async def aget_relevant_documents(self, query: str) -> List[Document]:
from zep_python import SearchPayload
async def aget_relevant_documents(
self, query: str, metadata: Optional[Dict] = None
) -> List[Document]:
from zep_python import MemorySearchPayload
payload: SearchPayload = SearchPayload(text=query)
payload: MemorySearchPayload = MemorySearchPayload(
text=query, metadata=metadata
)
results: List[SearchResult] = await self.zep_client.asearch_memory(
results: List[MemorySearchResult] = await self.zep_client.asearch_memory(
self.session_id, payload, limit=self.top_k
)

488
poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -90,7 +90,7 @@ pandas = {version = "^2.0.1", optional = true}
telethon = {version = "^1.28.5", optional = true}
neo4j = {version = "^5.8.1", optional = true}
psychicapi = {version = "^0.5", optional = true}
zep-python = {version="^0.30", optional=true}
zep-python = {version=">=0.31", optional=true}
langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true}
chardet = {version="^5.1.0", optional=true}
requests-toolbelt = {version = "^1.0.0", optional = true}

@ -10,12 +10,12 @@ from langchain.retrievers import ZepRetriever
from langchain.schema import Document
if TYPE_CHECKING:
from zep_python import SearchResult, ZepClient
from zep_python import MemorySearchResult, ZepClient
@pytest.fixture
def search_results() -> List[SearchResult]:
from zep_python import Message, SearchResult
def search_results() -> List[MemorySearchResult]:
from zep_python import MemorySearchResult, Message
search_result = [
{
@ -43,7 +43,7 @@ def search_results() -> List[SearchResult]:
]
return [
SearchResult(
MemorySearchResult(
message=Message.parse_obj(result["message"]),
summary=result["summary"],
dist=result["dist"],
@ -55,7 +55,7 @@ def search_results() -> List[SearchResult]:
@pytest.fixture
@pytest.mark.requires("zep_python")
def zep_retriever(
mocker: MockerFixture, search_results: List[SearchResult]
mocker: MockerFixture, search_results: List[MemorySearchResult]
) -> ZepRetriever:
mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True)
mock_zep_client.search_memory.return_value = copy.deepcopy( # type: ignore
@ -71,7 +71,7 @@ def zep_retriever(
@pytest.mark.requires("zep_python")
def test_zep_retriever_get_relevant_documents(
zep_retriever: ZepRetriever, search_results: List[SearchResult]
zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
) -> None:
documents: List[Document] = zep_retriever.get_relevant_documents(
query="My trip to Iceland"
@ -82,7 +82,7 @@ def test_zep_retriever_get_relevant_documents(
@pytest.mark.requires("zep_python")
@pytest.mark.asyncio
async def test_zep_retriever_aget_relevant_documents(
zep_retriever: ZepRetriever, search_results: List[SearchResult]
zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
) -> None:
documents: List[Document] = await zep_retriever.aget_relevant_documents(
query="My trip to Iceland"
@ -91,7 +91,7 @@ async def test_zep_retriever_aget_relevant_documents(
def _test_documents(
documents: List[Document], search_results: List[SearchResult]
documents: List[Document], search_results: List[MemorySearchResult]
) -> None:
assert len(documents) == 2
for i, document in enumerate(documents):

Loading…
Cancel
Save