2023-05-18 23:27:18 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import copy
|
|
|
|
from typing import TYPE_CHECKING, List
|
|
|
|
|
|
|
|
import pytest
|
2023-11-21 16:35:29 +00:00
|
|
|
from langchain_core.documents import Document
|
2023-05-18 23:27:18 +00:00
|
|
|
from pytest_mock import MockerFixture
|
|
|
|
|
2023-12-11 21:53:30 +00:00
|
|
|
from langchain_community.retrievers import ZepRetriever
|
2023-05-18 23:27:18 +00:00
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
2023-06-05 19:59:28 +00:00
|
|
|
from zep_python import MemorySearchResult, ZepClient
|
2023-05-18 23:27:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def search_results() -> List[MemorySearchResult]:
    """Return two canned Zep memory search hits about planning an Iceland trip.

    Each hit wraps a user message (parsed with ``Message.parse_obj``), has no
    summary, and carries a distance value that the retriever tests compare
    against ``Document.metadata["score"]``.
    """
    from zep_python import MemorySearchResult, Message

    raw_hits = [
        {
            "message": {
                "uuid": "66830914-19f5-490b-8677-1ba06bcd556b",
                "created_at": "2023-05-18T20:40:42.743773Z",
                "role": "user",
                "content": "I'm looking to plan a trip to Iceland. Can you help me?",
                "token_count": 17,
            },
            "summary": None,
            "dist": 0.8734284910450115,
        },
        {
            "message": {
                "uuid": "015e618c-ba9d-45b6-95c3-77a8e611570b",
                "created_at": "2023-05-18T20:40:42.743773Z",
                "role": "user",
                "content": "How much does a trip to Iceland typically cost?",
                "token_count": 12,
            },
            "summary": None,
            "dist": 0.8554048017463456,
        },
    ]

    results: List[MemorySearchResult] = []
    for hit in raw_hits:
        results.append(
            MemorySearchResult(
                message=Message.parse_obj(hit["message"]),
                summary=hit["summary"],
                dist=hit["dist"],
            )
        )
    return results
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def zep_retriever(
    mocker: MockerFixture, search_results: List[MemorySearchResult]
) -> ZepRetriever:
    """Build a ``ZepRetriever`` whose Zep client is replaced by autospec'd mocks.

    ``zep_python.ZepClient`` and ``zep_python.memory.client.MemoryClient`` are
    patched for the duration of the test. Both ``search_memory`` and
    ``asearch_memory`` return ``search_results``, so the sync and async
    retriever paths see the same data.

    No ``requires`` mark is applied here: pytest marks have no effect on
    fixtures, and newer pytest versions warn about it. Every test that
    consumes this fixture carries ``@pytest.mark.requires("zep_python")``.
    """
    mock_zep_client: ZepClient = mocker.patch("zep_python.ZepClient", autospec=True)
    mock_zep_client.memory = mocker.patch(
        "zep_python.memory.client.MemoryClient", autospec=True
    )
    # Each mock gets its own deep copy, so anything the retriever does to the
    # returned objects cannot leak into the fixture data used for assertions.
    mock_zep_client.memory.search_memory.return_value = copy.deepcopy(  # type: ignore
        search_results
    )
    mock_zep_client.memory.asearch_memory.return_value = copy.deepcopy(  # type: ignore
        search_results
    )
    zep = ZepRetriever(session_id="123", url="http://localhost:8000")
    # Swap in the mocked client so searches never reach a real server.
    zep.zep_client = mock_zep_client
    return zep
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("zep_python")
def test_zep_retriever_invoke(
    zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
) -> None:
    """The synchronous ``invoke`` path turns each mocked hit into a Document."""
    query = "My trip to Iceland"
    retrieved: List[Document] = zep_retriever.invoke(query)
    _test_documents(retrieved, search_results)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.requires("zep_python")
async def test_zep_retriever_ainvoke(
    zep_retriever: ZepRetriever, search_results: List[MemorySearchResult]
) -> None:
    """The asynchronous ``ainvoke`` path turns each mocked hit into a Document."""
    query = "My trip to Iceland"
    retrieved: List[Document] = await zep_retriever.ainvoke(query)
    _test_documents(retrieved, search_results)
|
|
|
|
|
|
|
|
|
|
|
|
def _test_documents(
|
2023-06-05 19:59:28 +00:00
|
|
|
documents: List[Document], search_results: List[MemorySearchResult]
|
2023-05-18 23:27:18 +00:00
|
|
|
) -> None:
|
|
|
|
assert len(documents) == 2
|
|
|
|
for i, document in enumerate(documents):
|
|
|
|
assert document.page_content == search_results[i].message.get( # type: ignore
|
|
|
|
"content"
|
|
|
|
)
|
2023-10-31 14:53:12 +00:00
|
|
|
assert document.metadata.get("uuid") == search_results[i].message.get( # type: ignore
|
2023-05-18 23:27:18 +00:00
|
|
|
"uuid"
|
|
|
|
)
|
2023-10-31 14:53:12 +00:00
|
|
|
assert document.metadata.get("role") == search_results[i].message.get( # type: ignore
|
2023-05-18 23:27:18 +00:00
|
|
|
"role"
|
|
|
|
)
|
|
|
|
assert document.metadata.get("score") == search_results[i].dist
|