Zep Hybrid Search (#5742)

Zep now supports persisting custom metadata with messages and hybrid
search across both message embeddings and structured metadata. This PR
implements custom metadata and enhancements to the
`ZepChatMessageHistory` and `ZepRetriever` classes to implement this
support.

Tag maintainers/contributors who might be interested:

  VectorStores / Retrievers / Memory
  - @dev2049

---------

Co-authored-by: Daniel Chalef <daniel.chalef@private.org>
This commit is contained in:
Daniel Chalef 2023-06-05 12:59:28 -07:00 committed by GitHub
parent a0ea6f6b6b
commit 0551bc90a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 493 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.schema import ( from langchain.schema import (
AIMessage, AIMessage,
@ -11,7 +11,7 @@ from langchain.schema import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from zep_python import Memory, Message, NotFoundError, SearchResult from zep_python import Memory, MemorySearchResult, Message, NotFoundError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -130,11 +130,15 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
self.zep_client.add_memory(self.session_id, zep_memory) self.zep_client.add_memory(self.session_id, zep_memory)
def search(self, query: str, limit: Optional[int] = None) -> List[SearchResult]: def search(
self, query: str, metadata: Optional[Dict] = None, limit: Optional[int] = None
) -> List[MemorySearchResult]:
"""Search Zep memory for messages matching the query""" """Search Zep memory for messages matching the query"""
from zep_python import SearchPayload from zep_python import MemorySearchPayload
payload: SearchPayload = SearchPayload(text=query) payload: MemorySearchPayload = MemorySearchPayload(
text=query, metadata=metadata
)
return self.zep_client.search_memory(self.session_id, payload, limit=limit) return self.zep_client.search_memory(self.session_id, payload, limit=limit)

View File

@ -1,11 +1,11 @@
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.schema import BaseRetriever, Document from langchain.schema import BaseRetriever, Document
if TYPE_CHECKING: if TYPE_CHECKING:
from zep_python import SearchResult from zep_python import MemorySearchResult
class ZepRetriever(BaseRetriever): class ZepRetriever(BaseRetriever):
@ -41,7 +41,9 @@ class ZepRetriever(BaseRetriever):
self.session_id = session_id self.session_id = session_id
self.top_k = top_k self.top_k = top_k
def _search_result_to_doc(self, results: List[SearchResult]) -> List[Document]: def _search_result_to_doc(
self, results: List[MemorySearchResult]
) -> List[Document]:
return [ return [
Document( Document(
page_content=r.message.pop("content"), page_content=r.message.pop("content"),
@ -51,23 +53,31 @@ class ZepRetriever(BaseRetriever):
if r.message if r.message
] ]
def get_relevant_documents(self, query: str) -> List[Document]: def get_relevant_documents(
from zep_python import SearchPayload self, query: str, metadata: Optional[Dict] = None
) -> List[Document]:
from zep_python import MemorySearchPayload
payload: SearchPayload = SearchPayload(text=query) payload: MemorySearchPayload = MemorySearchPayload(
text=query, metadata=metadata
)
results: List[SearchResult] = self.zep_client.search_memory( results: List[MemorySearchResult] = self.zep_client.search_memory(
self.session_id, payload, limit=self.top_k self.session_id, payload, limit=self.top_k
) )
return self._search_result_to_doc(results) return self._search_result_to_doc(results)
async def aget_relevant_documents(self, query: str) -> List[Document]: async def aget_relevant_documents(
from zep_python import SearchPayload self, query: str, metadata: Optional[Dict] = None
) -> List[Document]:
from zep_python import MemorySearchPayload
payload: SearchPayload = SearchPayload(text=query) payload: MemorySearchPayload = MemorySearchPayload(
text=query, metadata=metadata
)
results: List[SearchResult] = await self.zep_client.asearch_memory( results: List[MemorySearchResult] = await self.zep_client.asearch_memory(
self.session_id, payload, limit=self.top_k self.session_id, payload, limit=self.top_k
) )

488
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -90,7 +90,7 @@ pandas = {version = "^2.0.1", optional = true}
telethon = {version = "^1.28.5", optional = true} telethon = {version = "^1.28.5", optional = true}
neo4j = {version = "^5.8.1", optional = true} neo4j = {version = "^5.8.1", optional = true}
psychicapi = {version = "^0.5", optional = true} psychicapi = {version = "^0.5", optional = true}
zep-python = {version="^0.30", optional=true} zep-python = {version=">=0.31", optional=true}
langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true} langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true}
chardet = {version="^5.1.0", optional=true} chardet = {version="^5.1.0", optional=true}
requests-toolbelt = {version = "^1.0.0", optional = true} requests-toolbelt = {version = "^1.0.0", optional = true}

View File

@ -10,12 +10,12 @@ from langchain.retrievers import ZepRetriever
from langchain.schema import Document from langchain.schema import Document
if TYPE_CHECKING: if TYPE_CHECKING:
from zep_python import SearchResult, ZepClient from zep_python import MemorySearchResult, ZepClient
@pytest.fixture @pytest.fixture
def search_results() -> List[SearchResult]: def search_results() -> List[MemorySearchResult]:
from zep_python import Message, SearchResult from zep_python import MemorySearchResult, Message
search_result = [ search_result = [
{ {
@ -43,7 +43,7 @@ def search_results() -> List[SearchResult]:
] ]
return [ return [
SearchResult( MemorySearchResult(
message=Message.parse_obj(result["message"]), message=Message.parse_obj(result["message"]),
summary=result["summary"], summary=result["summary"],
dist=result["dist"], dist=result["dist"],
@ -55,7 +55,7 @@ def search_results() -> List[SearchResult]:
@pytest.fixture @pytest.fixture
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
def zep_retriever( def zep_retriever(
mocker: MockerFixture, search_results: List[SearchResult] mocker: MockerFixture, search_results: List[MemorySearchResult]
) -> ZepRetriever: ) -> ZepRetriever:
mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True) mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True)
mock_zep_client.search_memory.return_value = copy.deepcopy( # type: ignore mock_zep_client.search_memory.return_value = copy.deepcopy( # type: ignore
@ -71,7 +71,7 @@ def zep_retriever(
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
def test_zep_retriever_get_relevant_documents( def test_zep_retriever_get_relevant_documents(
zep_retriever: ZepRetriever, search_results: List[SearchResult] zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
) -> None: ) -> None:
documents: List[Document] = zep_retriever.get_relevant_documents( documents: List[Document] = zep_retriever.get_relevant_documents(
query="My trip to Iceland" query="My trip to Iceland"
@ -82,7 +82,7 @@ def test_zep_retriever_get_relevant_documents(
@pytest.mark.requires("zep_python") @pytest.mark.requires("zep_python")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_zep_retriever_aget_relevant_documents( async def test_zep_retriever_aget_relevant_documents(
zep_retriever: ZepRetriever, search_results: List[SearchResult] zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
) -> None: ) -> None:
documents: List[Document] = await zep_retriever.aget_relevant_documents( documents: List[Document] = await zep_retriever.aget_relevant_documents(
query="My trip to Iceland" query="My trip to Iceland"
@ -91,7 +91,7 @@ async def test_zep_retriever_aget_relevant_documents(
def _test_documents( def _test_documents(
documents: List[Document], search_results: List[SearchResult] documents: List[Document], search_results: List[MemorySearchResult]
) -> None: ) -> None:
assert len(documents) == 2 assert len(documents) == 2
for i, document in enumerate(documents): for i, document in enumerate(documents):