forked from Archives/langchain
Zep Hybrid Search (#5742)
Zep now supports persisting custom metadata with messages and hybrid search across both message embeddings and structured metadata. This PR implements custom metadata and enhancements to the `ZepChatMessageHistory` and `ZepRetriever` classes to implement this support. Tag maintainers/contributors who might be interested: VectorStores / Retrievers / Memory - @dev2049 --------- Co-authored-by: Daniel Chalef <daniel.chalef@private.org>
This commit is contained in:
parent
a0ea6f6b6b
commit
0551bc90a5
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
@ -11,7 +11,7 @@ from langchain.schema import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_python import Memory, Message, NotFoundError, SearchResult
|
||||
from zep_python import Memory, MemorySearchResult, Message, NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -130,11 +130,15 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
self.zep_client.add_memory(self.session_id, zep_memory)
|
||||
|
||||
def search(self, query: str, limit: Optional[int] = None) -> List[SearchResult]:
|
||||
def search(
|
||||
self, query: str, metadata: Optional[Dict] = None, limit: Optional[int] = None
|
||||
) -> List[MemorySearchResult]:
|
||||
"""Search Zep memory for messages matching the query"""
|
||||
from zep_python import SearchPayload
|
||||
from zep_python import MemorySearchPayload
|
||||
|
||||
payload: SearchPayload = SearchPayload(text=query)
|
||||
payload: MemorySearchPayload = MemorySearchPayload(
|
||||
text=query, metadata=metadata
|
||||
)
|
||||
|
||||
return self.zep_client.search_memory(self.session_id, payload, limit=limit)
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_python import SearchResult
|
||||
from zep_python import MemorySearchResult
|
||||
|
||||
|
||||
class ZepRetriever(BaseRetriever):
|
||||
@ -41,7 +41,9 @@ class ZepRetriever(BaseRetriever):
|
||||
self.session_id = session_id
|
||||
self.top_k = top_k
|
||||
|
||||
def _search_result_to_doc(self, results: List[SearchResult]) -> List[Document]:
|
||||
def _search_result_to_doc(
|
||||
self, results: List[MemorySearchResult]
|
||||
) -> List[Document]:
|
||||
return [
|
||||
Document(
|
||||
page_content=r.message.pop("content"),
|
||||
@ -51,23 +53,31 @@ class ZepRetriever(BaseRetriever):
|
||||
if r.message
|
||||
]
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
from zep_python import SearchPayload
|
||||
def get_relevant_documents(
|
||||
self, query: str, metadata: Optional[Dict] = None
|
||||
) -> List[Document]:
|
||||
from zep_python import MemorySearchPayload
|
||||
|
||||
payload: SearchPayload = SearchPayload(text=query)
|
||||
payload: MemorySearchPayload = MemorySearchPayload(
|
||||
text=query, metadata=metadata
|
||||
)
|
||||
|
||||
results: List[SearchResult] = self.zep_client.search_memory(
|
||||
results: List[MemorySearchResult] = self.zep_client.search_memory(
|
||||
self.session_id, payload, limit=self.top_k
|
||||
)
|
||||
|
||||
return self._search_result_to_doc(results)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
from zep_python import SearchPayload
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, metadata: Optional[Dict] = None
|
||||
) -> List[Document]:
|
||||
from zep_python import MemorySearchPayload
|
||||
|
||||
payload: SearchPayload = SearchPayload(text=query)
|
||||
payload: MemorySearchPayload = MemorySearchPayload(
|
||||
text=query, metadata=metadata
|
||||
)
|
||||
|
||||
results: List[SearchResult] = await self.zep_client.asearch_memory(
|
||||
results: List[MemorySearchResult] = await self.zep_client.asearch_memory(
|
||||
self.session_id, payload, limit=self.top_k
|
||||
)
|
||||
|
||||
|
488
poetry.lock
generated
488
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -90,7 +90,7 @@ pandas = {version = "^2.0.1", optional = true}
|
||||
telethon = {version = "^1.28.5", optional = true}
|
||||
neo4j = {version = "^5.8.1", optional = true}
|
||||
psychicapi = {version = "^0.5", optional = true}
|
||||
zep-python = {version="^0.30", optional=true}
|
||||
zep-python = {version=">=0.31", optional=true}
|
||||
langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true}
|
||||
chardet = {version="^5.1.0", optional=true}
|
||||
requests-toolbelt = {version = "^1.0.0", optional = true}
|
||||
|
@ -10,12 +10,12 @@ from langchain.retrievers import ZepRetriever
|
||||
from langchain.schema import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_python import SearchResult, ZepClient
|
||||
from zep_python import MemorySearchResult, ZepClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def search_results() -> List[SearchResult]:
|
||||
from zep_python import Message, SearchResult
|
||||
def search_results() -> List[MemorySearchResult]:
|
||||
from zep_python import MemorySearchResult, Message
|
||||
|
||||
search_result = [
|
||||
{
|
||||
@ -43,7 +43,7 @@ def search_results() -> List[SearchResult]:
|
||||
]
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
MemorySearchResult(
|
||||
message=Message.parse_obj(result["message"]),
|
||||
summary=result["summary"],
|
||||
dist=result["dist"],
|
||||
@ -55,7 +55,7 @@ def search_results() -> List[SearchResult]:
|
||||
@pytest.fixture
|
||||
@pytest.mark.requires("zep_python")
|
||||
def zep_retriever(
|
||||
mocker: MockerFixture, search_results: List[SearchResult]
|
||||
mocker: MockerFixture, search_results: List[MemorySearchResult]
|
||||
) -> ZepRetriever:
|
||||
mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True)
|
||||
mock_zep_client.search_memory.return_value = copy.deepcopy( # type: ignore
|
||||
@ -71,7 +71,7 @@ def zep_retriever(
|
||||
|
||||
@pytest.mark.requires("zep_python")
|
||||
def test_zep_retriever_get_relevant_documents(
|
||||
zep_retriever: ZepRetriever, search_results: List[SearchResult]
|
||||
zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
|
||||
) -> None:
|
||||
documents: List[Document] = zep_retriever.get_relevant_documents(
|
||||
query="My trip to Iceland"
|
||||
@ -82,7 +82,7 @@ def test_zep_retriever_get_relevant_documents(
|
||||
@pytest.mark.requires("zep_python")
|
||||
@pytest.mark.asyncio
|
||||
async def test_zep_retriever_aget_relevant_documents(
|
||||
zep_retriever: ZepRetriever, search_results: List[SearchResult]
|
||||
zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
|
||||
) -> None:
|
||||
documents: List[Document] = await zep_retriever.aget_relevant_documents(
|
||||
query="My trip to Iceland"
|
||||
@ -91,7 +91,7 @@ async def test_zep_retriever_aget_relevant_documents(
|
||||
|
||||
|
||||
def _test_documents(
|
||||
documents: List[Document], search_results: List[SearchResult]
|
||||
documents: List[Document], search_results: List[MemorySearchResult]
|
||||
) -> None:
|
||||
assert len(documents) == 2
|
||||
for i, document in enumerate(documents):
|
||||
|
Loading…
Reference in New Issue
Block a user