mirror of https://github.com/hwchase17/langchain
Add 'AnalyticDB', a database with full PostgreSQL syntax compatibility, as a vector store. (#3135)
Hi there!
I'm excited to open this PR to add support for using 'AnalyticDB', a
database fully compatible with PostgreSQL syntax, as a vector store.
As AnalyticDB has already been proven to work with AutoGPT,
ChatGPT-Retrieval-Plugin, and LlamaIndex, I think it would be a good fit for
LangChain as well.
AnalyticDB is a distributed, cloud-native vector database from Alibaba Cloud. It
is designed to perform well as data grows to large scale. The PR includes:
- [x] A new vector store: AnalyticDB
- [x] A suite of integration tests verifies the AnalyticDB integration
I have read your [contributing
guidelines](72b7d76d79/.github/CONTRIBUTING.md
).
And I have passed the tests below
- [x] make format
- [x] make lint
- [x] make coverage
- [x] make test
pull/3352/head
parent
cc6fe18152
commit
88a8f59aa7
@ -0,0 +1,15 @@
|
|||||||
|
# AnalyticDB
|
||||||
|
|
||||||
|
This page covers how to use the AnalyticDB ecosystem within LangChain.
|
||||||
|
|
||||||
|
### VectorStore
|
||||||
|
|
||||||
|
There exists a wrapper around AnalyticDB, allowing you to use it as a vectorstore,
|
||||||
|
whether for semantic search or example selection.
|
||||||
|
|
||||||
|
To import this vectorstore:
|
||||||
|
```python
|
||||||
|
from langchain.vectorstores import AnalyticDB
|
||||||
|
```
|
||||||
|
|
||||||
|
For a more detailed walkthrough of the AnalyticDB wrapper, see [this notebook](../modules/indexes/vectorstores/examples/analyticdb.ipynb)
|
@ -0,0 +1,432 @@
|
|||||||
|
"""VectorStore wrapper around an AnalyticDB (PostgreSQL-compatible) database."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from sqlalchemy import REAL, Index
|
||||||
|
from sqlalchemy.dialects.postgresql import ARRAY, JSON, UUID
|
||||||
|
from sqlalchemy.orm import Mapped, Session, declarative_base, relationship
|
||||||
|
from sqlalchemy.sql.expression import func
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
|
||||||
|
# Declarative base shared by every ORM model declared in this module.
Base: Any = declarative_base()

# Collection name used when the caller does not supply one.
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"

# NOTE(review): presumably the dimensionality of OpenAI ada embeddings; it equals
# the `dim` configured on the vector index of EmbeddingStore.
ADA_TOKEN_COUNT = 1536
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(Base):
    """Abstract declarative base that gives every table a UUID primary key."""

    __abstract__ = True

    # Primary key; a fresh uuid4 is generated by default for each new row.
    uuid = sqlalchemy.Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionStore(BaseModel):
    """ORM model for a named collection of embeddings.

    Rows are stored in the ``langchain_pg_collection`` table.
    """

    __tablename__ = "langchain_pg_collection"

    # Human-readable collection name used for lookups.
    name = sqlalchemy.Column(sqlalchemy.String)
    # Free-form JSON metadata attached to the collection.
    cmetadata = sqlalchemy.Column(JSON)

    # Embedding rows that belong to this collection.
    embeddings = relationship(
        "EmbeddingStore",
        back_populates="collection",
        passive_deletes=True,
    )

    @classmethod
    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
        """Return the first collection whose name equals ``name``, or None."""
        matching = session.query(cls).filter(cls.name == name)
        return matching.first()

    @classmethod
    def get_or_create(
        cls,
        session: Session,
        name: str,
        cmetadata: Optional[dict] = None,
    ) -> Tuple["CollectionStore", bool]:
        """Fetch the collection called ``name``, creating it when absent.

        Args:
            session: Session used for the lookup and, if needed, the insert.
            name: Name of the collection.
            cmetadata: Metadata stored on the collection only when it is created.

        Returns:
            A ``(collection, created)`` pair; ``created`` is True only when a new
            row was added and committed by this call.
        """
        existing = cls.get_by_name(session, name)
        if existing:
            return existing, False

        fresh = cls(name=name, cmetadata=cmetadata)
        session.add(fresh)
        session.commit()
        return fresh, True
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingStore(BaseModel):
    """ORM model for one embedded document.

    Rows are stored in the ``langchain_pg_embedding`` table and reference their
    owning collection through ``collection_id``.
    """

    __tablename__ = "langchain_pg_embedding"

    # Foreign key to the owning collection; the constraint is declared with
    # ON DELETE CASCADE.
    collection_id: Mapped[UUID] = sqlalchemy.Column(
        UUID(as_uuid=True),
        sqlalchemy.ForeignKey(
            f"{CollectionStore.__tablename__}.uuid",
            ondelete="CASCADE",
        ),
    )
    collection = relationship(CollectionStore, back_populates="embeddings")

    # The embedding vector, stored as an array of REAL values.
    embedding = sqlalchemy.Column(ARRAY(REAL))
    # The original text the embedding was computed from.
    document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    # Per-document JSON metadata.
    cmetadata = sqlalchemy.Column(JSON, nullable=True)

    # custom_id : any user defined id
    custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)

    # The following line creates an index named 'langchain_pg_embedding_vector_idx'
    langchain_pg_embedding_vector_idx = Index(
        "langchain_pg_embedding_vector_idx",
        embedding,
        postgresql_using="ann",
        postgresql_with={
            "distancemeasure": "L2",
            # Use the module constant instead of repeating the literal 1536 so
            # the index dimension has a single source of truth.
            "dim": ADA_TOKEN_COUNT,
            "pq_segments": 64,
            "hnsw_m": 100,
            "pq_centers": 2048,
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
class QueryResult:
    """Typing helper for rows produced by the similarity query.

    It declares only the two attributes read from each result row and is used
    purely as an annotation.
    """

    # The matched ORM row.
    EmbeddingStore: EmbeddingStore
    # Value of the column labelled "distance" in the query.
    distance: float
|
||||||
|
|
||||||
|
|
||||||
|
class AnalyticDB(VectorStore):
    """VectorStore implementation backed by AnalyticDB.

    AnalyticDB is a distributed, cloud-native database with full PostgreSQL
    syntax compatibility.

    - `connection_string` is a postgres connection string.
    - `embedding_function` any embedding function implementing
        `langchain.embeddings.base.Embeddings` interface.
    - `collection_name` is the name of the collection to use. (default: langchain)
        - NOTE: This is not the name of the table, but the name of the collection.
            The tables will be created when initializing the store (if not exists)
            So, make sure the user has the right permissions to create tables.
    - `collection_metadata` optional dict stored on the collection when it is
        created. (default: None)
    - `pre_delete_collection` if True, will delete the collection if it exists.
        (default: False)
        - Useful for testing.
    - `logger` optional logger; defaults to `logging.getLogger(__name__)`.
    """

    def __init__(
        self,
        connection_string: str,
        embedding_function: Embeddings,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        collection_metadata: Optional[dict] = None,
        pre_delete_collection: bool = False,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Store the configuration, then connect and prepare tables/collection."""
        self.connection_string = connection_string
        self.embedding_function = embedding_function
        self.collection_name = collection_name
        self.collection_metadata = collection_metadata
        self.pre_delete_collection = pre_delete_collection
        self.logger = logger or logging.getLogger(__name__)
        self.__post_init__()

    def __post_init__(
        self,
    ) -> None:
        """
        Initialize the store.

        Opens the connection, creates the tables if needed and makes sure the
        configured collection exists.
        """
        self._conn = self.connect()
        self.create_tables_if_not_exists()
        self.create_collection()

    def connect(self) -> sqlalchemy.engine.Connection:
        """Create a new engine from `connection_string` and return a connection."""
        engine = sqlalchemy.create_engine(self.connection_string)
        conn = engine.connect()
        return conn

    def create_tables_if_not_exists(self) -> None:
        """Create every table registered on `Base.metadata`."""
        Base.metadata.create_all(self._conn)

    def drop_tables(self) -> None:
        """Drop every table registered on `Base.metadata`."""
        Base.metadata.drop_all(self._conn)

    def create_collection(self) -> None:
        """Get or create the configured collection.

        When `pre_delete_collection` is set, the collection is deleted first.
        """
        if self.pre_delete_collection:
            self.delete_collection()
        with Session(self._conn) as session:
            CollectionStore.get_or_create(
                session, self.collection_name, cmetadata=self.collection_metadata
            )

    def delete_collection(self) -> None:
        """Delete the configured collection; log an error if it does not exist."""
        self.logger.debug("Trying to delete collection")
        with Session(self._conn) as session:
            collection = self.get_collection(session)
            if not collection:
                self.logger.error("Collection not found")
                return
            session.delete(collection)
            session.commit()

    def get_collection(self, session: Session) -> Optional["CollectionStore"]:
        """Look up the configured collection by name; None when it is missing."""
        return CollectionStore.get_by_name(session, self.collection_name)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids stored as `custom_id`; one uuid1 string
                per text is generated when omitted.
            kwargs: vectorstore specific parameters

        Returns:
            List of ids from adding the texts into the vectorstore.

        Raises:
            ValueError: If the configured collection does not exist.
        """
        # Materialize once: `texts` may be a one-shot iterator (e.g. a generator)
        # and is traversed several times below.
        texts = list(texts)

        if ids is None:
            ids = [str(uuid.uuid1()) for _ in texts]

        embeddings = self.embedding_function.embed_documents(texts)

        if not metadatas:
            metadatas = [{} for _ in texts]

        with Session(self._conn) as session:
            collection = self.get_collection(session)
            if not collection:
                raise ValueError("Collection not found")
            for text, metadata, embedding, custom_id in zip(
                texts, metadatas, embeddings, ids
            ):
                embedding_store = EmbeddingStore(
                    embedding=embedding,
                    document=text,
                    cmetadata=metadata,
                    custom_id=custom_id,
                )
                collection.embeddings.append(embedding_store)
                session.add(embedding_store)
            session.commit()

        return ids

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Run similarity search with AnalyticDB with distance.

        Args:
            query (str): Query text to search for.
            k (int): Number of results to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query.
        """
        embedding = self.embedding_function.embed_query(text=query)
        return self.similarity_search_by_vector(
            embedding=embedding,
            k=k,
            filter=filter,
        )

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[dict] = None,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query and score for each
        """
        embedding = self.embedding_function.embed_query(query)
        docs = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, filter=filter
        )
        return docs

    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[dict] = None,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to an embedding vector, with distances.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Optional metadata filter. Every key/value pair must match:
                the metadata value is compared, as text, with `str(value)`.

        Returns:
            List of `(Document, distance)` pairs, where the distance is the
            result of the SQL function `l2_distance`.

        Raises:
            ValueError: If the configured collection does not exist.
        """
        with Session(self._conn) as session:
            collection = self.get_collection(session)
            if not collection:
                raise ValueError("Collection not found")

            # Always restrict the search to the configured collection.
            filter_by = EmbeddingStore.collection_id == collection.uuid

            if filter is not None:
                filter_clauses = [
                    EmbeddingStore.cmetadata[key].astext == str(value)
                    for key, value in filter.items()
                ]
                filter_by = sqlalchemy.and_(filter_by, *filter_clauses)

            distance = func.l2_distance(EmbeddingStore.embedding, embedding)
            results: List[QueryResult] = (
                session.query(
                    EmbeddingStore,
                    distance.label("distance"),
                )
                .filter(filter_by)
                .order_by(EmbeddingStore.embedding.op("<->")(embedding))
                .join(
                    CollectionStore,
                    EmbeddingStore.collection_id == CollectionStore.uuid,
                )
                .limit(k)
                .all()
            )
            docs = [
                (
                    Document(
                        page_content=result.EmbeddingStore.document,
                        metadata=result.EmbeddingStore.cmetadata,
                    ),
                    result.distance if self.embedding_function is not None else None,
                )
                for result in results
            ]
            return docs

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[dict] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

        Returns:
            List of Documents most similar to the query vector.
        """
        docs_and_scores = self.similarity_search_with_score_by_vector(
            embedding=embedding, k=k, filter=filter
        )
        return [doc for doc, _ in docs_and_scores]

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        ids: Optional[List[str]] = None,
        pre_delete_collection: bool = False,
        **kwargs: Any,
    ) -> AnalyticDB:
        """
        Return VectorStore initialized from texts and embeddings.
        Postgres connection string is required.
        Either pass it as a parameter
        or set the PGVECTOR_CONNECTION_STRING environment variable.
        """

        connection_string = cls.get_connection_string(kwargs)

        store = cls(
            connection_string=connection_string,
            collection_name=collection_name,
            embedding_function=embedding,
            pre_delete_collection=pre_delete_collection,
        )

        store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
        return store

    @classmethod
    def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
        """Resolve the connection string from `kwargs` or the environment.

        The value is looked up with `get_from_dict_or_env` under the key
        ``connection_string`` and the environment variable
        ``PGVECTOR_CONNECTION_STRING``.

        Raises:
            ValueError: If the resolved connection string is empty.
        """
        connection_string: str = get_from_dict_or_env(
            data=kwargs,
            key="connection_string",
            env_key="PGVECTOR_CONNECTION_STRING",
        )

        if not connection_string:
            # The literals are concatenated implicitly, so each one carries its
            # own separator to keep the message readable.
            raise ValueError(
                "Postgres connection string is required. "
                "Either pass it as a parameter "
                "or set the PGVECTOR_CONNECTION_STRING environment variable."
            )

        return connection_string

    @classmethod
    def from_documents(
        cls,
        documents: List[Document],
        embedding: Embeddings,
        collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
        ids: Optional[List[str]] = None,
        pre_delete_collection: bool = False,
        **kwargs: Any,
    ) -> AnalyticDB:
        """
        Return VectorStore initialized from documents and embeddings.
        Postgres connection string is required.
        Either pass it as a parameter
        or set the PGVECTOR_CONNECTION_STRING environment variable.
        """

        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]
        connection_string = cls.get_connection_string(kwargs)

        # Forward the resolved value so from_texts finds it in kwargs.
        kwargs["connection_string"] = connection_string

        return cls.from_texts(
            texts=texts,
            pre_delete_collection=pre_delete_collection,
            embedding=embedding,
            metadatas=metadatas,
            ids=ids,
            collection_name=collection_name,
            **kwargs,
        )

    @classmethod
    def connection_string_from_db_params(
        cls,
        driver: str,
        host: str,
        port: int,
        database: str,
        user: str,
        password: str,
    ) -> str:
        """Return connection string from database parameters.

        The values are interpolated verbatim (no URL escaping is applied).
        """
        return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
|
@ -0,0 +1,148 @@
|
|||||||
|
"""Test AnalyticDB functionality."""
|
||||||
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.vectorstores.analyticdb import AnalyticDB
|
||||||
|
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||||
|
|
||||||
|
# Connection string for the integration tests; every part can be overridden
# through a PG_* environment variable.
CONNECTION_STRING = AnalyticDB.connection_string_from_db_params(
    driver=os.environ.get("PG_DRIVER", "psycopg2cffi"),
    host=os.environ.get("PG_HOST", "localhost"),
    # The port has its own variable: reading PG_HOST here made int() fail as
    # soon as a host name was configured.
    port=int(os.environ.get("PG_PORT", "5432")),
    database=os.environ.get("PG_DATABASE", "postgres"),
    user=os.environ.get("PG_USER", "postgres"),
    password=os.environ.get("PG_PASSWORD", "postgres"),
)

# Length of the fake embedding vectors produced in this module.
ADA_TOKEN_COUNT = 1536
|
||||||
|
|
||||||
|
|
||||||
|
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
    """Deterministic fake embeddings of length ADA_TOKEN_COUNT for testing."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one vector per text: all ones, with the text's index last."""
        ones = [1.0] * (ADA_TOKEN_COUNT - 1)
        return [ones + [float(position)] for position, _ in enumerate(texts)]

    def embed_query(self, text: str) -> List[float]:
        """Return the query vector: all ones, with 0.0 as the last component."""
        vector = [1.0] * (ADA_TOKEN_COUNT - 1)
        vector.append(0.0)
        return vector
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyticdb() -> None:
    """Build a store from three texts and fetch the single closest document."""
    store = AnalyticDB.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    closest = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo")]
    assert closest == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyticdb_with_metadatas() -> None:
    """Check that per-text metadata is returned with the closest document."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": str(idx)} for idx, _ in enumerate(corpus)]
    store = AnalyticDB.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    closest = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo", metadata={"page": "0"})]
    assert closest == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyticdb_with_metadatas_with_scores() -> None:
    """Check that the scored search returns the document, metadata and 0.0."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": str(idx)} for idx, _ in enumerate(corpus)]
    store = AnalyticDB.from_texts(
        texts=corpus,
        collection_name="test_collection",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    scored = store.similarity_search_with_score("foo", k=1)
    expected = [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
    assert scored == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyticdb_with_filter_match() -> None:
    """Filter on the metadata of the closest document and expect it back."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": str(idx)} for idx, _ in enumerate(corpus)]
    store = AnalyticDB.from_texts(
        texts=corpus,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    scored = store.similarity_search_with_score("foo", k=1, filter={"page": "0"})
    expected = [(Document(page_content="foo", metadata={"page": "0"}), 0.0)]
    assert scored == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyticdb_with_filter_distant_match() -> None:
    """Filter on the metadata of a farther document and expect only that one.

    The filter excludes the closest document, so the search must return the
    document tagged ``page == "2"`` with the distance asserted below.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": str(i)} for i, _ in enumerate(texts)]
    docsearch = AnalyticDB.from_texts(
        texts=texts,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=metadatas,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
    # The leftover debug print was removed; pytest shows `output` on failure.
    assert output == [(Document(page_content="baz", metadata={"page": "2"}), 4.0)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyticdb_with_filter_no_match() -> None:
    """Filter on a metadata value no document has and expect no results."""
    corpus = ["foo", "bar", "baz"]
    page_metadata = [{"page": str(idx)} for idx, _ in enumerate(corpus)]
    store = AnalyticDB.from_texts(
        texts=corpus,
        collection_name="test_collection_filter",
        embedding=FakeEmbeddingsWithAdaDimension(),
        metadatas=page_metadata,
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    scored = store.similarity_search_with_score("foo", k=1, filter={"page": "5"})
    assert scored == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_analyticdb_collection_with_metadata() -> None:
    """Create a collection with metadata and read it back from the database."""
    store = AnalyticDB(
        collection_name="test_collection",
        collection_metadata={"foo": "bar"},
        embedding_function=FakeEmbeddingsWithAdaDimension(),
        connection_string=CONNECTION_STRING,
        pre_delete_collection=True,
    )
    # Both the extra connection and the session are context managers; using
    # them in a `with` block releases them even when an assertion fails.
    with store.connect() as conn, Session(conn) as session:
        collection = store.get_collection(session)
        assert (
            collection is not None
        ), "Expected a CollectionStore object but received None"
        assert collection.name == "test_collection"
        assert collection.cmetadata == {"foo": "bar"}
|
Loading…
Reference in New Issue