From 549720ae51702cdf6a31eb176d2e3a769f6fc81e Mon Sep 17 00:00:00 2001 From: shibuiwilliam Date: Mon, 31 Jul 2023 03:42:25 +0900 Subject: [PATCH] add test to ensure values in time weighted retriever are updated (#8479) # What - add test to ensure values in time weighted retriever are updated --- .../retrievers/test_time_weighted_retriever.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py index d021ed93cb..2075a49b5b 100644 --- a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py +++ b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py @@ -1,6 +1,6 @@ """Tests for the time-weighted retriever class.""" -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Iterable, List, Optional, Tuple, Type import pytest @@ -139,7 +139,11 @@ def test_get_salient_docs( ) -> None: query = "Test query" docs_and_scores = time_weighted_retriever.get_salient_docs(query) + want = [(doc, 0.5) for doc in _get_example_memories()] assert isinstance(docs_and_scores, dict) + assert len(docs_and_scores) == len(want) + for k, doc in docs_and_scores.items(): + assert doc in want def test_get_relevant_documents( @@ -147,7 +151,17 @@ def test_get_relevant_documents( ) -> None: query = "Test query" relevant_documents = time_weighted_retriever.get_relevant_documents(query) + want = [(doc, 0.5) for doc in _get_example_memories()] assert isinstance(relevant_documents, list) + assert len(relevant_documents) == len(want) + now = datetime.now() + for doc in relevant_documents: + # assert that the last_accessed_at is close to now. + assert now - timedelta(hours=1) < doc.metadata["last_accessed_at"] <= now + + # assert that the last_accessed_at in the memory stream is updated. + for d in time_weighted_retriever.memory_stream: + assert now - timedelta(hours=1) < d.metadata["last_accessed_at"] <= now def test_add_documents(