forked from Archives/langchain
Zep Hybrid Search (#5742)
Zep now supports persisting custom metadata with messages and hybrid search across both message embeddings and structured metadata. This PR implements custom metadata and enhancements to the `ZepChatMessageHistory` and `ZepRetriever` classes to implement this support. Tag maintainers/contributors who might be interested: VectorStores / Retrievers / Memory - @dev2049 --------- Co-authored-by: Daniel Chalef <daniel.chalef@private.org>
This commit is contained in:
parent
a0ea6f6b6b
commit
0551bc90a5
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -11,7 +11,7 @@ from langchain.schema import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from zep_python import Memory, Message, NotFoundError, SearchResult
|
from zep_python import Memory, MemorySearchResult, Message, NotFoundError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -130,11 +130,15 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
|
|||||||
|
|
||||||
self.zep_client.add_memory(self.session_id, zep_memory)
|
self.zep_client.add_memory(self.session_id, zep_memory)
|
||||||
|
|
||||||
def search(self, query: str, limit: Optional[int] = None) -> List[SearchResult]:
|
def search(
|
||||||
|
self, query: str, metadata: Optional[Dict] = None, limit: Optional[int] = None
|
||||||
|
) -> List[MemorySearchResult]:
|
||||||
"""Search Zep memory for messages matching the query"""
|
"""Search Zep memory for messages matching the query"""
|
||||||
from zep_python import SearchPayload
|
from zep_python import MemorySearchPayload
|
||||||
|
|
||||||
payload: SearchPayload = SearchPayload(text=query)
|
payload: MemorySearchPayload = MemorySearchPayload(
|
||||||
|
text=query, metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
return self.zep_client.search_memory(self.session_id, payload, limit=limit)
|
return self.zep_client.search_memory(self.session_id, payload, limit=limit)
|
||||||
|
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.schema import BaseRetriever, Document
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from zep_python import SearchResult
|
from zep_python import MemorySearchResult
|
||||||
|
|
||||||
|
|
||||||
class ZepRetriever(BaseRetriever):
|
class ZepRetriever(BaseRetriever):
|
||||||
@ -41,7 +41,9 @@ class ZepRetriever(BaseRetriever):
|
|||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
||||||
def _search_result_to_doc(self, results: List[SearchResult]) -> List[Document]:
|
def _search_result_to_doc(
|
||||||
|
self, results: List[MemorySearchResult]
|
||||||
|
) -> List[Document]:
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=r.message.pop("content"),
|
page_content=r.message.pop("content"),
|
||||||
@ -51,23 +53,31 @@ class ZepRetriever(BaseRetriever):
|
|||||||
if r.message
|
if r.message
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
def get_relevant_documents(
|
||||||
from zep_python import SearchPayload
|
self, query: str, metadata: Optional[Dict] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
from zep_python import MemorySearchPayload
|
||||||
|
|
||||||
payload: SearchPayload = SearchPayload(text=query)
|
payload: MemorySearchPayload = MemorySearchPayload(
|
||||||
|
text=query, metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
results: List[SearchResult] = self.zep_client.search_memory(
|
results: List[MemorySearchResult] = self.zep_client.search_memory(
|
||||||
self.session_id, payload, limit=self.top_k
|
self.session_id, payload, limit=self.top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._search_result_to_doc(results)
|
return self._search_result_to_doc(results)
|
||||||
|
|
||||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
async def aget_relevant_documents(
|
||||||
from zep_python import SearchPayload
|
self, query: str, metadata: Optional[Dict] = None
|
||||||
|
) -> List[Document]:
|
||||||
|
from zep_python import MemorySearchPayload
|
||||||
|
|
||||||
payload: SearchPayload = SearchPayload(text=query)
|
payload: MemorySearchPayload = MemorySearchPayload(
|
||||||
|
text=query, metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
results: List[SearchResult] = await self.zep_client.asearch_memory(
|
results: List[MemorySearchResult] = await self.zep_client.asearch_memory(
|
||||||
self.session_id, payload, limit=self.top_k
|
self.session_id, payload, limit=self.top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
488
poetry.lock
generated
488
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -90,7 +90,7 @@ pandas = {version = "^2.0.1", optional = true}
|
|||||||
telethon = {version = "^1.28.5", optional = true}
|
telethon = {version = "^1.28.5", optional = true}
|
||||||
neo4j = {version = "^5.8.1", optional = true}
|
neo4j = {version = "^5.8.1", optional = true}
|
||||||
psychicapi = {version = "^0.5", optional = true}
|
psychicapi = {version = "^0.5", optional = true}
|
||||||
zep-python = {version="^0.30", optional=true}
|
zep-python = {version=">=0.31", optional=true}
|
||||||
langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true}
|
langkit = {version = ">=0.0.1.dev3, <0.1.0", optional = true}
|
||||||
chardet = {version="^5.1.0", optional=true}
|
chardet = {version="^5.1.0", optional=true}
|
||||||
requests-toolbelt = {version = "^1.0.0", optional = true}
|
requests-toolbelt = {version = "^1.0.0", optional = true}
|
||||||
|
@ -10,12 +10,12 @@ from langchain.retrievers import ZepRetriever
|
|||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from zep_python import SearchResult, ZepClient
|
from zep_python import MemorySearchResult, ZepClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def search_results() -> List[SearchResult]:
|
def search_results() -> List[MemorySearchResult]:
|
||||||
from zep_python import Message, SearchResult
|
from zep_python import MemorySearchResult, Message
|
||||||
|
|
||||||
search_result = [
|
search_result = [
|
||||||
{
|
{
|
||||||
@ -43,7 +43,7 @@ def search_results() -> List[SearchResult]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
SearchResult(
|
MemorySearchResult(
|
||||||
message=Message.parse_obj(result["message"]),
|
message=Message.parse_obj(result["message"]),
|
||||||
summary=result["summary"],
|
summary=result["summary"],
|
||||||
dist=result["dist"],
|
dist=result["dist"],
|
||||||
@ -55,7 +55,7 @@ def search_results() -> List[SearchResult]:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@pytest.mark.requires("zep_python")
|
@pytest.mark.requires("zep_python")
|
||||||
def zep_retriever(
|
def zep_retriever(
|
||||||
mocker: MockerFixture, search_results: List[SearchResult]
|
mocker: MockerFixture, search_results: List[MemorySearchResult]
|
||||||
) -> ZepRetriever:
|
) -> ZepRetriever:
|
||||||
mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True)
|
mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True)
|
||||||
mock_zep_client.search_memory.return_value = copy.deepcopy( # type: ignore
|
mock_zep_client.search_memory.return_value = copy.deepcopy( # type: ignore
|
||||||
@ -71,7 +71,7 @@ def zep_retriever(
|
|||||||
|
|
||||||
@pytest.mark.requires("zep_python")
|
@pytest.mark.requires("zep_python")
|
||||||
def test_zep_retriever_get_relevant_documents(
|
def test_zep_retriever_get_relevant_documents(
|
||||||
zep_retriever: ZepRetriever, search_results: List[SearchResult]
|
zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
|
||||||
) -> None:
|
) -> None:
|
||||||
documents: List[Document] = zep_retriever.get_relevant_documents(
|
documents: List[Document] = zep_retriever.get_relevant_documents(
|
||||||
query="My trip to Iceland"
|
query="My trip to Iceland"
|
||||||
@ -82,7 +82,7 @@ def test_zep_retriever_get_relevant_documents(
|
|||||||
@pytest.mark.requires("zep_python")
|
@pytest.mark.requires("zep_python")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_zep_retriever_aget_relevant_documents(
|
async def test_zep_retriever_aget_relevant_documents(
|
||||||
zep_retriever: ZepRetriever, search_results: List[SearchResult]
|
zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
|
||||||
) -> None:
|
) -> None:
|
||||||
documents: List[Document] = await zep_retriever.aget_relevant_documents(
|
documents: List[Document] = await zep_retriever.aget_relevant_documents(
|
||||||
query="My trip to Iceland"
|
query="My trip to Iceland"
|
||||||
@ -91,7 +91,7 @@ async def test_zep_retriever_aget_relevant_documents(
|
|||||||
|
|
||||||
|
|
||||||
def _test_documents(
|
def _test_documents(
|
||||||
documents: List[Document], search_results: List[SearchResult]
|
documents: List[Document], search_results: List[MemorySearchResult]
|
||||||
) -> None:
|
) -> None:
|
||||||
assert len(documents) == 2
|
assert len(documents) == 2
|
||||||
for i, document in enumerate(documents):
|
for i, document in enumerate(documents):
|
||||||
|
Loading…
Reference in New Issue
Block a user