mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
3bae595182
Similar to #1813 for FAISS, this PR extends the PGVector wrapper so that (text, embedding) pairs can be passed in directly, both to initialize the store and to add embeddings to it. Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: - @dev2049
187 lines
6.8 KiB
Python
187 lines
6.8 KiB
Python
"""Test PGVector functionality."""
|
|
import os
|
|
from typing import List
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.vectorstores.pgvector import PGVector
|
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
|
|
|
# Connection string for the test database. Every component can be overridden
# through a TEST_PGVECTOR_* environment variable; the fallbacks target a local
# Postgres on the default port with the default "postgres" credentials.
CONNECTION_STRING = PGVector.connection_string_from_db_params(
    driver=os.getenv("TEST_PGVECTOR_DRIVER", "psycopg2"),
    host=os.getenv("TEST_PGVECTOR_HOST", "localhost"),
    port=int(os.getenv("TEST_PGVECTOR_PORT", "5432")),
    database=os.getenv("TEST_PGVECTOR_DATABASE", "postgres"),
    user=os.getenv("TEST_PGVECTOR_USER", "postgres"),
    password=os.getenv("TEST_PGVECTOR_PASSWORD", "postgres"),
)


# Length of every vector produced by FakeEmbeddingsWithAdaDimension. Despite
# the name, this is used as an embedding dimension, not a token count.
ADA_TOKEN_COUNT = 1_536
|
|
|
|
|
|
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
    """Deterministic fake embeddings whose vectors have ADA_TOKEN_COUNT entries.

    Every vector is all ones except for the final component, which lets tests
    control which stored text is nearest to a query.
    """

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed the text at position ``n`` as ones followed by ``float(n)``."""
        prefix_length = ADA_TOKEN_COUNT - 1
        vectors = []
        for position in range(len(texts)):
            vectors.append([1.0] * prefix_length + [float(position)])
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """Embed any query exactly like the first document (last component 0.0)."""
        prefix_length = ADA_TOKEN_COUNT - 1
        return [1.0] * prefix_length + [0.0]
|
|
|
|
|
|
def test_pgvector() -> None:
    """Build a store from raw texts and check that "foo" is its own nearest neighbour."""
    corpus = ["foo", "bar", "baz"]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    results = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo")]
    assert results == expected
|
|
|
|
|
|
def test_pgvector_embeddings() -> None:
    """Build a store from precomputed (text, vector) pairs, then run a search."""
    corpus = ["foo", "bar", "baz"]
    embedder = FakeEmbeddingsWithAdaDimension()
    vectors = embedder.embed_documents(corpus)
    pairs = [(text, vector) for text, vector in zip(corpus, vectors)]
    store = PGVector.from_embeddings(
        text_embeddings=pairs,
        collection_name="test_collection",
        embedding=embedder,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    results = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo")]
    assert results == expected
|
|
|
|
|
|
def test_pgvector_with_metadatas() -> None:
    """Check that per-document metadata round-trips through from_texts and search."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [dict(page=str(n)) for n in range(len(corpus))]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    results = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo", metadata={"page": "0"})]
    assert results == expected
|
|
|
|
|
|
def test_pgvector_with_metadatas_with_scores() -> None:
    """Check that an exact-match query returns its metadata with a 0.0 score."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [dict(page=str(n)) for n in range(len(corpus))]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    results = store.similarity_search_with_score("foo", k=1)
    best_match = Document(page_content="foo", metadata={"page": "0"})
    assert results == [(best_match, 0.0)]
|
|
|
|
|
|
def test_pgvector_with_filter_match() -> None:
    """Filter on a metadata value that the nearest document also carries."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [dict(page=str(n)) for n in range(len(corpus))]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    page_filter = {"page": "0"}
    results = store.similarity_search_with_score("foo", k=1, filter=page_filter)
    best_match = Document(page_content="foo", metadata={"page": "0"})
    assert results == [(best_match, 0.0)]
|
|
|
|
|
|
def test_pgvector_with_filter_distant_match() -> None:
    """Test that a metadata filter can force a non-nearest document to be returned.

    The filter restricts the search to ``page == "2"`` ("baz"), so the query
    "foo" must match "baz" even though "foo" itself is the closest vector.
    """
    import math

    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
    # Compare the document exactly, but compare the score (a float returned by
    # the vector store) with a tolerance. Exact equality against a long float
    # literal is brittle.
    assert len(output) == 1
    doc, score = output[0]
    assert doc == Document(page_content="baz", metadata={"page": "2"})
    assert math.isclose(score, 0.0013003906671379406, abs_tol=1e-6)
|
|
|
|
|
|
def test_pgvector_with_filter_no_match() -> None:
    """Check that a filter matching no stored metadata yields an empty result."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [dict(page=str(n)) for n in range(len(corpus))]
    store = PGVector.from_texts(
        texts=corpus,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # Only pages "0".."2" exist, so page "5" cannot match anything.
    missing_page = {"page": "5"}
    results = store.similarity_search_with_score("foo", k=1, filter=missing_page)
    assert results == []
|
|
|
|
|
|
def test_pgvector_collection_with_metadata() -> None:
    """Test that collection metadata is stored when the collection is created."""
    pgvector = PGVector(
        collection_name="test_collection",
        collection_metadata={"foo": "bar"},
        embedding_function=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # Use context managers so the connection and session are closed when the
    # test finishes instead of leaking.
    with pgvector.connect() as connection, Session(connection) as session:
        collection = pgvector.get_collection(session)
        # `assert False` inside an if/else is stripped under `python -O`;
        # assert the condition directly instead.
        assert (
            collection is not None
        ), "Expected a CollectionStore object but received None"
        assert collection.name == "test_collection"
        assert collection.cmetadata == {"foo": "bar"}
|
|
|
|
|
|
def test_pgvector_with_filter_in_set() -> None:
    """Test an ``IN`` metadata filter restricting the search to pages "0" and "2".

    Only "foo" (page "0") and "baz" (page "2") may be returned, in order of
    increasing distance from the query "foo".
    """
    import math

    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i in range(len(texts))]
    docsearch = PGVector.from_texts(
        texts=texts,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    output = docsearch.similarity_search_with_score(
        "foo", k=2, filter={"page": {"IN": ["0", "2"]}}
    )
    expected = [
        (Document(page_content="foo", metadata={"page": "0"}), 0.0),
        (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406),
    ]
    # Documents (and therefore result count and order) must match exactly.
    assert [doc for doc, _ in output] == [doc for doc, _ in expected]
    # Scores are floats returned by the vector store; compare them with a
    # tolerance rather than exact equality.
    for (_, score), (_, expected_score) in zip(output, expected):
        assert math.isclose(score, expected_score, abs_tol=1e-6)
|