"""Tests for the time-weighted retriever class."""
from datetime import datetime
from typing import Any, Iterable, List, Optional, Tuple, Type
import pytest
from langchain.embeddings.base import Embeddings
from langchain.retrievers.time_weighted_retriever import (
TimeWeightedVectorStoreRetriever,
_get_hours_passed,
)
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore


def _get_example_memories(k: int = 4) -> List[Document]:
    return [
        Document(
            page_content="foo",
            metadata={
                "buffer_idx": i,
                "last_accessed_at": datetime(2023, 4, 14, 12, 0),
            },
        )
        for i in range(k)
    ]


class MockVectorStore(VectorStore):
    """Mock vector store for testing; not a functional implementation."""

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        return list(texts)

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore."""
        raise NotImplementedError

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query."""
        return []

    @classmethod
    def from_documents(
        cls: Type["MockVectorStore"],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return VectorStore initialized from documents and embeddings."""
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)

    @classmethod
    def from_texts(
        cls: Type["MockVectorStore"],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> "MockVectorStore":
        """Return VectorStore initialized from texts and embeddings."""
        return cls()

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and similarity scores, normalized on a scale from 0 to 1.

        0 is dissimilar, 1 is most similar.
        """
        return [(doc, 0.5) for doc in _get_example_memories()]
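

# The fixture below wires the retriever to the MockVectorStore above and seeds
# its memory stream with four example memories last accessed at 2023-04-14 12:00.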
@pytest.fixture
def time_weighted_retriever() -> TimeWeightedVectorStoreRetriever:
    vectorstore = MockVectorStore()
    return TimeWeightedVectorStoreRetriever(
        vectorstore=vectorstore, memory_stream=_get_example_memories()
    )


def test__get_hours_passed() -> None:
    time1 = datetime(2023, 4, 14, 14, 30)
    time2 = datetime(2023, 4, 14, 12, 0)
    expected_hours_passed = 2.5
    hours_passed = _get_hours_passed(time1, time2)
    assert hours_passed == expected_hours_passed
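

# _get_combined_score is expected to combine recency and relevance roughly as
# (1 - decay_rate) ** hours_since_last_access + vector_relevance; the expected
# value below recomputes that by hand for a memory last accessed 2.5 hours ago.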
def test_get_combined_score(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    document = Document(
        page_content="Test document",
        metadata={"last_accessed_at": datetime(2023, 4, 14, 12, 0)},
    )
    vector_salience = 0.7
    expected_hours_passed = 2.5
    current_time = datetime(2023, 4, 14, 14, 30)
    combined_score = time_weighted_retriever._get_combined_score(
        document, vector_salience, current_time
    )
    expected_score = (
        1.0 - time_weighted_retriever.decay_rate
    ) ** expected_hours_passed + vector_salience
    assert combined_score == pytest.approx(expected_score)
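

# get_salient_docs should return a mapping from buffer index to
# (document, relevance score) pairs; only the container type is asserted here.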
def test_get_salient_docs(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    query = "Test query"
    docs_and_scores = time_weighted_retriever.get_salient_docs(query)
    assert isinstance(docs_and_scores, dict)


def test_get_relevant_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    query = "Test query"
    relevant_documents = time_weighted_retriever.get_relevant_documents(query)
    assert isinstance(relevant_documents, list)
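

# add_documents should append the new document to the retriever's memory_stream
# (and push it into the vector store), which the assertions below check.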
def test_add_documents(
    time_weighted_retriever: TimeWeightedVectorStoreRetriever,
) -> None:
    documents = [Document(page_content="test_add_documents document")]
    added_documents = time_weighted_retriever.add_documents(documents)
    assert isinstance(added_documents, list)
    assert len(added_documents) == 1
    assert (
        time_weighted_retriever.memory_stream[-1].page_content
        == documents[0].page_content
    )