Improve AnalyticDB Vector Store implementation without affecting users (#6086)

Hi there:

The previous AnalyticDB VectorStore implementation used two tables to store
documents; using a single table is a better approach. This commit therefore
improves the AnalyticDB VectorStore implementation without affecting
user-facing behavior:

1. Streamline the `__post_init__` behavior by creating a single table with
vector indexing.
2. Update the `add_texts` API for document insertion.
3. Optimize `similarity_search_with_score_by_vector` to retrieve results
directly from the table.
4. Implement `_similarity_search_with_relevance_scores`.
5. Add an `embedding_dimension` parameter to support embedding functions of
different dimensions.

Users can continue using the API as before.
The previously added test cases are sufficient to cover this change.
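For illustration, here is a minimal usage sketch (not part of the diff) showing that the public API is unchanged. It assumes a reachable AnalyticDB instance and uses `OpenAIEmbeddings` as a stand-in for any 1536-dimension embedding function; `embedding_dimension` is the only new, optional knob:

```python
import os

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.analyticdb import AnalyticDB

# Connection parameters mirror the integration tests; adjust for your instance.
connection_string = AnalyticDB.connection_string_from_db_params(
    driver="psycopg2cffi",
    host=os.environ.get("PG_HOST", "localhost"),
    port=int(os.environ.get("PG_PORT", "5432")),
    database=os.environ.get("PG_DATABASE", "postgres"),
    user=os.environ.get("PG_USER", "postgres"),
    password=os.environ.get("PG_PASSWORD", "postgres"),
)

# Same entry point as before; OpenAIEmbeddings returns 1536-dim vectors,
# matching the default embedding dimension.
docsearch = AnalyticDB.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=OpenAIEmbeddings(),
    embedding_dimension=1536,
    collection_name="langchain_document",
    connection_string=connection_string,
)

# Query APIs behave exactly as before the refactor.
docs_with_scores = docsearch.similarity_search_with_score("foo", k=1)
```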
Richy Wang, 2023-06-18 00:36:31 +08:00, committed by GitHub
parent cdd1d78bf2, commit 444ca3f669
2 changed files with 145 additions and 186 deletions

langchain/vectorstores/analyticdb.py

@@ -3,114 +3,28 @@ from __future__ import annotations
import logging
import uuid
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type
import sqlalchemy
from sqlalchemy import REAL, Index
from sqlalchemy.dialects.postgresql import ARRAY, JSON, UUID
from sqlalchemy import REAL, Column, String, Table, create_engine, insert, text
from sqlalchemy.dialects.postgresql import ARRAY, JSON, TEXT
from sqlalchemy.engine import Row
try:
from sqlalchemy.orm import declarative_base
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
from sqlalchemy.sql.expression import func
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
_LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"
Base = declarative_base() # type: Any
ADA_TOKEN_COUNT = 1536
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
class BaseModel(Base):
__abstract__ = True
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
class CollectionStore(BaseModel):
__tablename__ = "langchain_pg_collection"
name = sqlalchemy.Column(sqlalchemy.String)
cmetadata = sqlalchemy.Column(JSON)
embeddings = relationship(
"EmbeddingStore",
back_populates="collection",
passive_deletes=True,
)
@classmethod
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
return session.query(cls).filter(cls.name == name).first() # type: ignore
@classmethod
def get_or_create(
cls,
session: Session,
name: str,
cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
"""
Get or create a collection.
Returns [Collection, bool] where the bool is True if the collection was created.
"""
created = False
collection = cls.get_by_name(session, name)
if collection:
return collection, created
collection = cls(name=name, cmetadata=cmetadata)
session.add(collection)
session.commit()
created = True
return collection, created
class EmbeddingStore(BaseModel):
__tablename__ = "langchain_pg_embedding"
collection_id = sqlalchemy.Column(
UUID(as_uuid=True),
sqlalchemy.ForeignKey(
f"{CollectionStore.__tablename__}.uuid",
ondelete="CASCADE",
),
)
collection = relationship(CollectionStore, back_populates="embeddings")
embedding: sqlalchemy.Column = sqlalchemy.Column(ARRAY(REAL))
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
cmetadata = sqlalchemy.Column(JSON, nullable=True)
# custom_id : any user defined id
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
# The following line creates an index named 'langchain_pg_embedding_vector_idx'
langchain_pg_embedding_vector_idx = Index(
"langchain_pg_embedding_vector_idx",
embedding,
postgresql_using="ann",
postgresql_with={
"distancemeasure": "L2",
"dim": 1536,
"pq_segments": 64,
"hnsw_m": 100,
"pq_centers": 2048,
},
)
class QueryResult:
EmbeddingStore: EmbeddingStore
distance: float
class AnalyticDB(VectorStore):
"""
VectorStore implementation using AnalyticDB.
@@ -132,15 +46,15 @@ class AnalyticDB(VectorStore):
self,
connection_string: str,
embedding_function: Embeddings,
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
collection_metadata: Optional[dict] = None,
pre_delete_collection: bool = False,
logger: Optional[logging.Logger] = None,
) -> None:
self.connection_string = connection_string
self.embedding_function = embedding_function
self.embedding_dimension = embedding_dimension
self.collection_name = collection_name
self.collection_metadata = collection_metadata
self.pre_delete_collection = pre_delete_collection
self.logger = logger or logging.getLogger(__name__)
self.__post_init__()
@@ -151,47 +65,68 @@ class AnalyticDB(VectorStore):
"""
Initialize the store.
"""
self._conn = self.connect()
self.create_tables_if_not_exists()
self.engine = create_engine(self.connection_string)
self.create_collection()
def connect(self) -> sqlalchemy.engine.Connection:
engine = sqlalchemy.create_engine(self.connection_string)
conn = engine.connect()
return conn
def create_table_if_not_exists(self) -> None:
# Define the dynamic table
Table(
self.collection_name,
Base.metadata,
Column("id", TEXT, primary_key=True, default=uuid.uuid4),
Column("embedding", ARRAY(REAL)),
Column("document", String, nullable=True),
Column("metadata", JSON, nullable=True),
extend_existing=True,
)
with self.engine.connect() as conn:
# Create the table
Base.metadata.create_all(conn)
def create_tables_if_not_exists(self) -> None:
Base.metadata.create_all(self._conn)
# Check if the index exists
index_name = f"{self.collection_name}_embedding_idx"
index_query = text(
f"""
SELECT 1
FROM pg_indexes
WHERE indexname = '{index_name}';
"""
)
result = conn.execute(index_query).scalar()
def drop_tables(self) -> None:
Base.metadata.drop_all(self._conn)
# Create the index if it doesn't exist
if not result:
index_statement = text(
f"""
CREATE INDEX {index_name}
ON {self.collection_name} USING ann(embedding)
WITH (
"dim" = {self.embedding_dimension},
"hnsw_m" = 100
);
"""
)
conn.execute(index_statement)
conn.commit()
def create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
with Session(self._conn) as session:
CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
self.create_table_if_not_exists()
def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection")
with Session(self._conn) as session:
collection = self.get_collection(session)
if not collection:
self.logger.error("Collection not found")
return
session.delete(collection)
session.commit()
def get_collection(self, session: Session) -> Optional["CollectionStore"]:
return CollectionStore.get_by_name(session, self.collection_name)
drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
with self.engine.connect() as conn:
conn.execute(drop_statement)
conn.commit()
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
batch_size: int = 500,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
@@ -212,20 +147,43 @@ class AnalyticDB(VectorStore):
if not metadatas:
metadatas = [{} for _ in texts]
with Session(self._conn) as session:
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
embedding_store = EmbeddingStore(
embedding=embedding,
document=text,
cmetadata=metadata,
custom_id=id,
# Define the table schema
chunks_table = Table(
self.collection_name,
Base.metadata,
Column("id", TEXT, primary_key=True),
Column("embedding", ARRAY(REAL)),
Column("document", String, nullable=True),
Column("metadata", JSON, nullable=True),
extend_existing=True,
)
chunks_table_data = []
with self.engine.connect() as conn:
for document, metadata, chunk_id, embedding in zip(
texts, metadatas, ids, embeddings
):
chunks_table_data.append(
{
"id": chunk_id,
"embedding": embedding,
"document": document,
"metadata": metadata,
}
)
collection.embeddings.append(embedding_store)
session.add(embedding_store)
session.commit()
# Execute the batch insert when the batch size is reached
if len(chunks_table_data) == batch_size:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Clear the chunks_table_data list for the next batch
chunks_table_data.clear()
# Insert any remaining records that didn't make up a full batch
if chunks_table_data:
conn.execute(insert(chunks_table).values(chunks_table_data))
# Commit the transaction only once after all records have been inserted
conn.commit()
return ids
@@ -275,52 +233,69 @@ class AnalyticDB(VectorStore):
)
return docs
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 and 1 used to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
return self.similarity_search_with_score(query, k, **kwargs)
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
with Session(self._conn) as session:
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
filter_by = EmbeddingStore.collection_id == collection.uuid
# Add the filter if provided
filter_condition = ""
if filter is not None:
filter_clauses = []
for key, value in filter.items():
filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value)
filter_clauses.append(filter_by_metadata)
conditions = [
f"metadata->>{key!r} = {value!r}" for key, value in filter.items()
]
filter_condition = f"WHERE {' AND '.join(conditions)}"
filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
# Define the base query
sql_query = f"""
SELECT *, l2_distance(embedding, :embedding) as distance
FROM {self.collection_name}
{filter_condition}
ORDER BY embedding <-> :embedding
LIMIT :k
"""
results: List[QueryResult] = (
session.query(
EmbeddingStore,
func.l2_distance(EmbeddingStore.embedding, embedding).label("distance"),
)
.filter(filter_by)
.order_by(EmbeddingStore.embedding.op("<->")(embedding))
.join(
CollectionStore,
EmbeddingStore.collection_id == CollectionStore.uuid,
)
.limit(k)
.all()
)
docs = [
# Set up the query parameters
params = {"embedding": embedding, "k": k}
# Execute the query and fetch the results
with self.engine.connect() as conn:
results: Sequence[Row] = conn.execute(text(sql_query), params).fetchall()
documents_with_scores = [
(
Document(
page_content=result.EmbeddingStore.document,
metadata=result.EmbeddingStore.cmetadata,
page_content=result.document,
metadata=result.metadata,
),
result.distance if self.embedding_function is not None else None,
)
for result in results
]
return docs
return documents_with_scores
def similarity_search_by_vector(
self,
@@ -346,10 +321,11 @@ class AnalyticDB(VectorStore):
@classmethod
def from_texts(
cls,
cls: Type[AnalyticDB],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
@@ -368,6 +344,7 @@ class AnalyticDB(VectorStore):
connection_string=connection_string,
collection_name=collection_name,
embedding_function=embedding,
embedding_dimension=embedding_dimension,
pre_delete_collection=pre_delete_collection,
)
@@ -379,7 +356,7 @@ class AnalyticDB(VectorStore):
connection_string: str = get_from_dict_or_env(
data=kwargs,
key="connection_string",
env_key="PGVECTOR_CONNECTION_STRING",
env_key="PG_CONNECTION_STRING",
)
if not connection_string:
@@ -393,9 +370,10 @@ class AnalyticDB(VectorStore):
@classmethod
def from_documents(
cls,
cls: Type[AnalyticDB],
documents: List[Document],
embedding: Embeddings,
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
@@ -418,6 +396,7 @@ class AnalyticDB(VectorStore):
texts=texts,
pre_delete_collection=pre_delete_collection,
embedding=embedding,
embedding_dimension=embedding_dimension,
metadatas=metadatas,
ids=ids,
collection_name=collection_name,

tests/integration_tests/vectorstores/test_analyticdb.py

@@ -2,8 +2,6 @@
import os
from typing import List
from sqlalchemy.orm import Session
from langchain.docstore.document import Document
from langchain.vectorstores.analyticdb import AnalyticDB
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@@ -11,7 +9,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
CONNECTION_STRING = AnalyticDB.connection_string_from_db_params(
driver=os.environ.get("PG_DRIVER", "psycopg2cffi"),
host=os.environ.get("PG_HOST", "localhost"),
port=int(os.environ.get("PG_HOST", "5432")),
port=int(os.environ.get("PG_PORT", "5432")),
database=os.environ.get("PG_DATABASE", "postgres"),
user=os.environ.get("PG_USER", "postgres"),
password=os.environ.get("PG_PASSWORD", "postgres"),
@@ -128,21 +126,3 @@ def test_analyticdb_with_filter_no_match() -> None:
)
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
assert output == []
def test_analyticdb_collection_with_metadata() -> None:
"""Test end to end collection construction"""
pgvector = AnalyticDB(
collection_name="test_collection",
collection_metadata={"foo": "bar"},
embedding_function=FakeEmbeddingsWithAdaDimension(),
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
session = Session(pgvector.connect())
collection = pgvector.get_collection(session)
if collection is None:
assert False, "Expected a CollectionStore object but received None"
else:
assert collection.name == "test_collection"
assert collection.cmetadata == {"foo": "bar"}