mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
5171c3bcca
Description: This pull request aims to support generating the correct generic relevancy scores for different vector stores by refactoring the relevance score functions and their selection in the base class and subclasses of VectorStore. This is especially relevant with VectorStores that require a distance metric upon initialization. Note that many of the current implementations of `_similarity_search_with_relevance_scores` are not technically correct, as they just return `self.similarity_search_with_score(query, k, **kwargs)` without applying the relevant score function. Also includes changes associated with: https://github.com/hwchase17/langchain/pull/6564 and https://github.com/hwchase17/langchain/pull/6494 See the more in-depth discussion in the thread in #6494. Issue: https://github.com/hwchase17/langchain/issues/6526 https://github.com/hwchase17/langchain/issues/6481 https://github.com/hwchase17/langchain/issues/6346 Dependencies: None The changes include: - Properly handling score thresholding in FAISS `similarity_search_with_score_by_vector` for the corresponding distance metric. - Refactoring the `_similarity_search_with_relevance_scores` method in the base class and removing it from the subclasses for incorrectly implemented subclasses. - Adding a `_select_relevance_score_fn` method in the base class and implementing it in the subclasses to select the appropriate relevance score function based on the distance strategy. - Updating the `__init__` methods of the subclasses to set the `relevance_score_fn` attribute. - Removing the `_default_relevance_score_fn` function from the FAISS class and using the base class's `_euclidean_relevance_score_fn` instead. - Adding the `DistanceStrategy` enum to the `utils.py` file and updating the imports in the vector store classes. - Updating the tests to import the `DistanceStrategy` enum from the `utils.py` file. --------- Co-authored-by: Hanit <37485638+hanit-com@users.noreply.github.com>
254 lines
9.2 KiB
Python
254 lines
9.2 KiB
Python
"""Test PGVector functionality."""
|
|
import math
import os
from typing import List

from sqlalchemy.orm import Session

from langchain.docstore.document import Document
from langchain.vectorstores.pgvector import PGVector
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
|
|
|
# Database coordinates are read from TEST_PGVECTOR_* environment variables;
# each falls back to the default shown here when the variable is unset.
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver=os.getenv("TEST_PGVECTOR_DRIVER", "psycopg2"),
    host=os.getenv("TEST_PGVECTOR_HOST", "localhost"),
    port=int(os.getenv("TEST_PGVECTOR_PORT", "5432")),
    database=os.getenv("TEST_PGVECTOR_DATABASE", "postgres"),
    user=os.getenv("TEST_PGVECTOR_USER", "postgres"),
    password=os.getenv("TEST_PGVECTOR_PASSWORD", "postgres"),
)

# Length of every vector produced by the fake embedder defined in this module.
ADA_TOKEN_COUNT = 1536
|
|
|
|
|
|
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
    """Deterministic fake embedder emitting ADA_TOKEN_COUNT-long vectors."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one vector per text: all ones, with the text's index last.

        The text content is ignored; only its position in ``texts`` matters.
        """
        shared_prefix = [1.0] * (ADA_TOKEN_COUNT - 1)
        return [shared_prefix + [float(position)] for position, _ in enumerate(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Return the query vector: all ones with ``0.0`` in the last slot.

        The query text is ignored, so every query maps to the same vector
        that ``embed_documents`` assigns to the first document.
        """
        vector = [1.0] * ADA_TOKEN_COUNT
        vector[-1] = 0.0
        return vector
|
|
|
|
|
|
def test_pgvector() -> None:
    """Build a store from raw texts and check the top-1 similarity hit."""
    store = PGVector.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    hits = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo")]
    assert hits == expected
|
|
|
|
|
|
def test_pgvector_embeddings() -> None:
    """Build a store from precomputed (text, embedding) pairs and search it."""
    corpus = ["foo", "bar", "baz"]
    vectors = FakeEmbeddingsWithAdaDimension().embed_documents(corpus)
    store = PGVector.from_embeddings(
        text_embeddings=list(zip(corpus, vectors)),
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    hits = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo")]
    assert hits == expected
|
|
|
|
|
|
def test_pgvector_with_metadatas() -> None:
    """Check that per-document metadata comes back with the top-1 hit."""
    corpus = ["foo", "bar", "baz"]
    # Each document is tagged with its list position as a string page number.
    page_tags = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_tags,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    hits = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo", metadata={"page": "0"})]
    assert hits == expected
|
|
|
|
|
|
def test_pgvector_with_metadatas_with_scores() -> None:
    """Check the (document, score) pair returned for the top-1 hit."""
    corpus = ["foo", "bar", "baz"]
    # Each document is tagged with its list position as a string page number.
    page_tags = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_tags,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    scored_hits = store.similarity_search_with_score("foo", k=1)
    best_match = Document(page_content="foo", metadata={"page": "0"})
    assert scored_hits == [(best_match, 0.0)]
|
|
|
|
|
|
def test_pgvector_with_filter_match() -> None:
    """Search with a metadata filter that admits the top-scoring document."""
    corpus = ["foo", "bar", "baz"]
    # Each document is tagged with its list position as a string page number.
    page_tags = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_tags,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    scored_hits = store.similarity_search_with_score(
        "foo", k=1, filter={"page": "0"}
    )
    best_match = Document(page_content="foo", metadata={"page": "0"})
    assert scored_hits == [(best_match, 0.0)]
|
|
|
|
|
|
def test_pgvector_with_filter_distant_match() -> None:
    """Test that a metadata filter can select a document other than "foo".

    The filter only admits page "2", so "baz" is expected back together
    with a small non-zero score.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
    # Documents are compared exactly; this also pins the number of results.
    assert [doc for doc, _ in output] == [
        Document(page_content="baz", metadata={"page": "2"})
    ]
    # The score is a computed floating-point value, so compare it with a
    # tolerance rather than bit-for-bit equality; harmless rounding
    # differences should not fail the test.
    assert math.isclose(output[0][1], 0.0013003906671379406, abs_tol=1e-6)
|
|
|
|
|
|
def test_pgvector_with_filter_no_match() -> None:
    """Search with a metadata filter that no stored document satisfies."""
    corpus = ["foo", "bar", "baz"]
    # Pages run "0".."2", so the page "5" filter below matches nothing.
    page_tags = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_tags,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    scored_hits = store.similarity_search_with_score(
        "foo", k=1, filter={"page": "5"}
    )
    assert scored_hits == []
|
|
|
|
|
|
def test_pgvector_collection_with_metadata() -> None:
    """Test end to end collection construction.

    Creates a PGVector store with collection-level metadata, then reads the
    collection back through a session and checks its name and metadata.
    """
    pgvector = PGVector(
        collection_name="test_collection",
        collection_metadata={"foo": "bar"},
        embedding_function=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    session = Session(pgvector.connect())
    try:
        collection = pgvector.get_collection(session)
        assert (
            collection is not None
        ), "Expected a CollectionStore object but received None"
        assert collection.name == "test_collection"
        assert collection.cmetadata == {"foo": "bar"}
    finally:
        # Close the session even when an assertion fails so it is not left
        # open for the tests that run afterwards.
        session.close()
|
|
|
|
|
|
def test_pgvector_with_filter_in_set() -> None:
    """Test searching with an ``IN`` metadata filter over several pages.

    The filter admits pages "0" and "2", so "foo" and "baz" are expected
    back, in that order, each with its score.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    output = docsearch.similarity_search_with_score(
        "foo", k=2, filter={"page": {"IN": ["0", "2"]}}
    )
    # Documents are compared exactly; this also pins the number and order of
    # the results, so the zip below cannot silently skip a score.
    assert [doc for doc, _ in output] == [
        Document(page_content="foo", metadata={"page": "0"}),
        Document(page_content="baz", metadata={"page": "2"}),
    ]
    # Scores are computed floating-point values, so compare them with a
    # tolerance rather than bit-for-bit equality.
    expected_scores = [0.0, 0.0013003906671379406]
    for (_, score), expected in zip(output, expected_scores):
        assert math.isclose(score, expected, abs_tol=1e-6)
|
|
|
|
|
|
def test_pgvector_relevance_score() -> None:
    """Test to make sure the relevance score is scaled to 0-1.

    All three documents are requested and are expected back in the order
    "foo", "bar", "baz" with decreasing relevance scores.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )

    output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
    # Documents are compared exactly; this also pins the number and order of
    # the results, so the zip below cannot silently skip a score.
    assert [doc for doc, _ in output] == [
        Document(page_content="foo", metadata={"page": "0"}),
        Document(page_content="bar", metadata={"page": "1"}),
        Document(page_content="baz", metadata={"page": "2"}),
    ]
    # Scores are computed floating-point values, so compare them with a
    # tolerance rather than bit-for-bit equality. The tolerance is far
    # smaller than the gaps between the expected scores.
    expected_scores = [1.0, 0.9996744261675065, 0.9986996093328621]
    for (_, score), expected in zip(output, expected_scores):
        assert math.isclose(score, expected, abs_tol=1e-6)
|
|
|
|
|
|
def test_pgvector_retriever_search_threshold() -> None:
    """Check that a score-threshold retriever returns only "foo" and "bar"."""
    corpus = ["foo", "bar", "baz"]
    # Each document is tagged with its list position as a string page number.
    page_tags = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_tags,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )

    threshold_kwargs = {"k": 3, "score_threshold": 0.999}
    retriever = store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs=threshold_kwargs,
    )
    found = retriever.get_relevant_documents("summer")
    expected = [
        Document(page_content="foo", metadata={"page": "0"}),
        Document(page_content="bar", metadata={"page": "1"}),
    ]
    assert found == expected
|
|
|
|
|
|
def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
    """Check threshold retrieval when a custom relevance function is supplied.

    The supplied function multiplies its input by zero, and the retriever is
    expected to return nothing at a 0.5 threshold.
    """
    corpus = ["foo", "bar", "baz"]
    # Each document is tagged with its list position as a string page number.
    page_tags = [{"page": str(position)} for position, _ in enumerate(corpus)]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_tags,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
        relevance_score_fn=lambda d: d * 0,
    )

    threshold_kwargs = {"k": 3, "score_threshold": 0.5}
    retriever = store.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs=threshold_kwargs,
    )
    found = retriever.get_relevant_documents("foo")
    assert found == []
|