Improve AnalyticDB Vector Store implementation without affecting users (#6086)

Hi there:

My earlier AnalyticDB VectorStore implementation used two tables to store documents. Using a single
table turns out to be the better approach, so this commit improves the AnalyticDB VectorStore
implementation without affecting user behavior:

**1. Streamline the `__post_init__` behavior by creating a single table with vector indexing.
2. Update the `add_texts` API for document insertion.
3. Optimize `similarity_search_with_score_by_vector` to retrieve results directly from the table.
4. Implement `_similarity_search_with_relevance_scores`.
5. Add an `embedding_dimension` parameter to support embedding functions with different dimensions.**

Users can continue using the API exactly as before; a short usage sketch follows below.
The previously added test cases are sufficient to cover this commit.
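
For illustration, here is a minimal usage sketch (not part of this diff). The connection parameters and the embedding model are placeholders; any `Embeddings` implementation works, as long as `embedding_dimension` matches the size of the vectors it produces:

```python
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores.analyticdb import AnalyticDB

# Placeholder connection parameters -- point these at your own AnalyticDB instance.
CONNECTION_STRING = AnalyticDB.connection_string_from_db_params(
    driver="psycopg2",
    host="localhost",
    port=5432,
    database="postgres",
    user="postgres",
    password="postgres",
)

# Any Embeddings implementation works; embedding_dimension must match its output size.
embeddings = OpenAIEmbeddings()

# Documents and vectors now live in a single table named after the collection,
# with an ANN index created on the embedding column.
vector_store = AnalyticDB.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=embeddings,
    embedding_dimension=1536,
    collection_name="langchain_document",
    connection_string=CONNECTION_STRING,
    pre_delete_collection=True,
)

# Relevance scores are returned alongside the matched documents.
for doc, score in vector_store.similarity_search_with_relevance_scores("foo", k=2):
    print(doc.page_content, score)
```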
Richy Wang 2023-06-18 00:36:31 +08:00 committed by GitHub
parent cdd1d78bf2
commit 444ca3f669
2 changed files with 145 additions and 186 deletions

langchain/vectorstores/analyticdb.py

@@ -3,114 +3,28 @@ from __future__ import annotations
 import logging
 import uuid
-from typing import Any, Dict, Iterable, List, Optional, Tuple
-import sqlalchemy
-from sqlalchemy import REAL, Index
-from sqlalchemy.dialects.postgresql import ARRAY, JSON, UUID
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type
+from sqlalchemy import REAL, Column, String, Table, create_engine, insert, text
+from sqlalchemy.dialects.postgresql import ARRAY, JSON, TEXT
+from sqlalchemy.engine import Row
 try:
     from sqlalchemy.orm import declarative_base
 except ImportError:
     from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import Session, relationship
-from sqlalchemy.sql.expression import func
 from langchain.docstore.document import Document
 from langchain.embeddings.base import Embeddings
 from langchain.utils import get_from_dict_or_env
 from langchain.vectorstores.base import VectorStore
+_LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"
 Base = declarative_base()  # type: Any
-ADA_TOKEN_COUNT = 1536
-_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
-class BaseModel(Base):
-    __abstract__ = True
-    uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
-class CollectionStore(BaseModel):
-    __tablename__ = "langchain_pg_collection"
-    name = sqlalchemy.Column(sqlalchemy.String)
-    cmetadata = sqlalchemy.Column(JSON)
-    embeddings = relationship(
-        "EmbeddingStore",
-        back_populates="collection",
-        passive_deletes=True,
-    )
-    @classmethod
-    def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
-        return session.query(cls).filter(cls.name == name).first()  # type: ignore
-    @classmethod
-    def get_or_create(
-        cls,
-        session: Session,
-        name: str,
-        cmetadata: Optional[dict] = None,
-    ) -> Tuple["CollectionStore", bool]:
-        """
-        Get or create a collection.
-        Returns [Collection, bool] where the bool is True if the collection was created.
-        """
-        created = False
-        collection = cls.get_by_name(session, name)
-        if collection:
-            return collection, created
-        collection = cls(name=name, cmetadata=cmetadata)
-        session.add(collection)
-        session.commit()
-        created = True
-        return collection, created
-class EmbeddingStore(BaseModel):
-    __tablename__ = "langchain_pg_embedding"
-    collection_id = sqlalchemy.Column(
-        UUID(as_uuid=True),
-        sqlalchemy.ForeignKey(
-            f"{CollectionStore.__tablename__}.uuid",
-            ondelete="CASCADE",
-        ),
-    )
-    collection = relationship(CollectionStore, back_populates="embeddings")
-    embedding: sqlalchemy.Column = sqlalchemy.Column(ARRAY(REAL))
-    document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
-    cmetadata = sqlalchemy.Column(JSON, nullable=True)
-    # custom_id : any user defined id
-    custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
-    # The following line creates an index named 'langchain_pg_embedding_vector_idx'
-    langchain_pg_embedding_vector_idx = Index(
-        "langchain_pg_embedding_vector_idx",
-        embedding,
-        postgresql_using="ann",
-        postgresql_with={
-            "distancemeasure": "L2",
-            "dim": 1536,
-            "pq_segments": 64,
-            "hnsw_m": 100,
-            "pq_centers": 2048,
-        },
-    )
-class QueryResult:
-    EmbeddingStore: EmbeddingStore
-    distance: float
 class AnalyticDB(VectorStore):
     """
     VectorStore implementation using AnalyticDB.
@@ -132,15 +46,15 @@ class AnalyticDB(VectorStore):
         self,
         connection_string: str,
         embedding_function: Embeddings,
+        embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
-        collection_metadata: Optional[dict] = None,
         pre_delete_collection: bool = False,
         logger: Optional[logging.Logger] = None,
     ) -> None:
         self.connection_string = connection_string
         self.embedding_function = embedding_function
+        self.embedding_dimension = embedding_dimension
         self.collection_name = collection_name
-        self.collection_metadata = collection_metadata
         self.pre_delete_collection = pre_delete_collection
         self.logger = logger or logging.getLogger(__name__)
         self.__post_init__()
@@ -151,47 +65,68 @@ class AnalyticDB(VectorStore):
         """
         Initialize the store.
         """
-        self._conn = self.connect()
-        self.create_tables_if_not_exists()
+        self.engine = create_engine(self.connection_string)
         self.create_collection()
-    def connect(self) -> sqlalchemy.engine.Connection:
-        engine = sqlalchemy.create_engine(self.connection_string)
-        conn = engine.connect()
-        return conn
-    def create_tables_if_not_exists(self) -> None:
-        Base.metadata.create_all(self._conn)
-    def drop_tables(self) -> None:
-        Base.metadata.drop_all(self._conn)
+    def create_table_if_not_exists(self) -> None:
+        # Define the dynamic table
+        Table(
+            self.collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True, default=uuid.uuid4),
+            Column("embedding", ARRAY(REAL)),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+        with self.engine.connect() as conn:
+            # Create the table
+            Base.metadata.create_all(conn)
+            # Check if the index exists
+            index_name = f"{self.collection_name}_embedding_idx"
+            index_query = text(
+                f"""
+                SELECT 1
+                FROM pg_indexes
+                WHERE indexname = '{index_name}';
+                """
+            )
+            result = conn.execute(index_query).scalar()
+            # Create the index if it doesn't exist
+            if not result:
+                index_statement = text(
+                    f"""
+                    CREATE INDEX {index_name}
+                    ON {self.collection_name} USING ann(embedding)
+                    WITH (
+                        "dim" = {self.embedding_dimension},
+                        "hnsw_m" = 100
+                    );
+                    """
+                )
+                conn.execute(index_statement)
+            conn.commit()
     def create_collection(self) -> None:
         if self.pre_delete_collection:
             self.delete_collection()
-        with Session(self._conn) as session:
-            CollectionStore.get_or_create(
-                session, self.collection_name, cmetadata=self.collection_metadata
-            )
+        self.create_table_if_not_exists()
     def delete_collection(self) -> None:
         self.logger.debug("Trying to delete collection")
-        with Session(self._conn) as session:
-            collection = self.get_collection(session)
-            if not collection:
-                self.logger.error("Collection not found")
-                return
-            session.delete(collection)
-            session.commit()
-    def get_collection(self, session: Session) -> Optional["CollectionStore"]:
-        return CollectionStore.get_by_name(session, self.collection_name)
+        drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
+        with self.engine.connect() as conn:
+            conn.execute(drop_statement)
+            conn.commit()
     def add_texts(
         self,
         texts: Iterable[str],
         metadatas: Optional[List[dict]] = None,
         ids: Optional[List[str]] = None,
+        batch_size: int = 500,
         **kwargs: Any,
     ) -> List[str]:
         """Run more texts through the embeddings and add to the vectorstore.
@@ -212,20 +147,43 @@ class AnalyticDB(VectorStore):
         if not metadatas:
             metadatas = [{} for _ in texts]
-        with Session(self._conn) as session:
-            collection = self.get_collection(session)
-            if not collection:
-                raise ValueError("Collection not found")
-            for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
-                embedding_store = EmbeddingStore(
-                    embedding=embedding,
-                    document=text,
-                    cmetadata=metadata,
-                    custom_id=id,
-                )
-                collection.embeddings.append(embedding_store)
-                session.add(embedding_store)
-            session.commit()
+        # Define the table schema
+        chunks_table = Table(
+            self.collection_name,
+            Base.metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("embedding", ARRAY(REAL)),
+            Column("document", String, nullable=True),
+            Column("metadata", JSON, nullable=True),
+            extend_existing=True,
+        )
+        chunks_table_data = []
+        with self.engine.connect() as conn:
+            for document, metadata, chunk_id, embedding in zip(
+                texts, metadatas, ids, embeddings
+            ):
+                chunks_table_data.append(
+                    {
+                        "id": chunk_id,
+                        "embedding": embedding,
+                        "document": document,
+                        "metadata": metadata,
+                    }
+                )
+                # Execute the batch insert when the batch size is reached
+                if len(chunks_table_data) == batch_size:
+                    conn.execute(insert(chunks_table).values(chunks_table_data))
+                    # Clear the chunks_table_data list for the next batch
+                    chunks_table_data.clear()
+            # Insert any remaining records that didn't make up a full batch
+            if chunks_table_data:
+                conn.execute(insert(chunks_table).values(chunks_table_data))
+            # Commit the transaction only once after all records have been inserted
+            conn.commit()
         return ids
@@ -275,52 +233,69 @@ class AnalyticDB(VectorStore):
         )
         return docs
+    def _similarity_search_with_relevance_scores(
+        self,
+        query: str,
+        k: int = 4,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs and relevance scores in the range [0, 1].
+        0 is dissimilar, 1 is most similar.
+        Args:
+            query: input text
+            k: Number of Documents to return. Defaults to 4.
+            **kwargs: kwargs to be passed to similarity search. Should include:
+                score_threshold: Optional, a floating point value between 0 to 1 to
+                    filter the resulting set of retrieved docs
+        Returns:
+            List of Tuples of (doc, similarity_score)
+        """
+        return self.similarity_search_with_score(query, k, **kwargs)
     def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
     ) -> List[Tuple[Document, float]]:
-        with Session(self._conn) as session:
-            collection = self.get_collection(session)
-            if not collection:
-                raise ValueError("Collection not found")
-            filter_by = EmbeddingStore.collection_id == collection.uuid
-            if filter is not None:
-                filter_clauses = []
-                for key, value in filter.items():
-                    filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value)
-                    filter_clauses.append(filter_by_metadata)
-                filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
-            results: List[QueryResult] = (
-                session.query(
-                    EmbeddingStore,
-                    func.l2_distance(EmbeddingStore.embedding, embedding).label("distance"),
-                )
-                .filter(filter_by)
-                .order_by(EmbeddingStore.embedding.op("<->")(embedding))
-                .join(
-                    CollectionStore,
-                    EmbeddingStore.collection_id == CollectionStore.uuid,
-                )
-                .limit(k)
-                .all()
-            )
-        docs = [
-            (
-                Document(
-                    page_content=result.EmbeddingStore.document,
-                    metadata=result.EmbeddingStore.cmetadata,
-                ),
-                result.distance if self.embedding_function is not None else None,
-            )
-            for result in results
-        ]
-        return docs
+        # Add the filter if provided
+        filter_condition = ""
+        if filter is not None:
+            conditions = [
+                f"metadata->>{key!r} = {value!r}" for key, value in filter.items()
+            ]
+            filter_condition = f"WHERE {' AND '.join(conditions)}"
+        # Define the base query
+        sql_query = f"""
+            SELECT *, l2_distance(embedding, :embedding) as distance
+            FROM {self.collection_name}
+            {filter_condition}
+            ORDER BY embedding <-> :embedding
+            LIMIT :k
+        """
+        # Set up the query parameters
+        params = {"embedding": embedding, "k": k}
+        # Execute the query and fetch the results
+        with self.engine.connect() as conn:
+            results: Sequence[Row] = conn.execute(text(sql_query), params).fetchall()
+        documents_with_scores = [
+            (
+                Document(
+                    page_content=result.document,
+                    metadata=result.metadata,
+                ),
+                result.distance if self.embedding_function is not None else None,
+            )
+            for result in results
+        ]
+        return documents_with_scores
     def similarity_search_by_vector(
         self,
@@ -346,10 +321,11 @@ class AnalyticDB(VectorStore):
     @classmethod
     def from_texts(
-        cls,
+        cls: Type[AnalyticDB],
         texts: List[str],
         embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,
+        embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
@@ -368,6 +344,7 @@ class AnalyticDB(VectorStore):
             connection_string=connection_string,
             collection_name=collection_name,
             embedding_function=embedding,
+            embedding_dimension=embedding_dimension,
             pre_delete_collection=pre_delete_collection,
         )
@@ -379,7 +356,7 @@ class AnalyticDB(VectorStore):
         connection_string: str = get_from_dict_or_env(
             data=kwargs,
             key="connection_string",
-            env_key="PGVECTOR_CONNECTION_STRING",
+            env_key="PG_CONNECTION_STRING",
         )
         if not connection_string:
@@ -393,9 +370,10 @@ class AnalyticDB(VectorStore):
     @classmethod
     def from_documents(
-        cls,
+        cls: Type[AnalyticDB],
         documents: List[Document],
         embedding: Embeddings,
+        embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
@@ -418,6 +396,7 @@ class AnalyticDB(VectorStore):
             texts=texts,
             pre_delete_collection=pre_delete_collection,
             embedding=embedding,
+            embedding_dimension=embedding_dimension,
             metadatas=metadatas,
             ids=ids,
             collection_name=collection_name,

tests/integration_tests/vectorstores/test_analyticdb.py

@@ -2,8 +2,6 @@
 import os
 from typing import List
-from sqlalchemy.orm import Session
 from langchain.docstore.document import Document
 from langchain.vectorstores.analyticdb import AnalyticDB
 from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
@@ -11,7 +9,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
 CONNECTION_STRING = AnalyticDB.connection_string_from_db_params(
     driver=os.environ.get("PG_DRIVER", "psycopg2cffi"),
     host=os.environ.get("PG_HOST", "localhost"),
-    port=int(os.environ.get("PG_HOST", "5432")),
+    port=int(os.environ.get("PG_PORT", "5432")),
     database=os.environ.get("PG_DATABASE", "postgres"),
     user=os.environ.get("PG_USER", "postgres"),
     password=os.environ.get("PG_PASSWORD", "postgres"),
@@ -128,21 +126,3 @@ def test_analyticdb_with_filter_no_match() -> None:
     )
     output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
     assert output == []
-def test_analyticdb_collection_with_metadata() -> None:
-    """Test end to end collection construction"""
-    pgvector = AnalyticDB(
-        collection_name="test_collection",
-        collection_metadata={"foo": "bar"},
-        embedding_function=FakeEmbeddingsWithAdaDimension(),
-        connection_string=CONNECTION_STRING,
-        pre_delete_collection=True,
-    )
-    session = Session(pgvector.connect())
-    collection = pgvector.get_collection(session)
-    if collection is None:
-        assert False, "Expected a CollectionStore object but received None"
-    else:
-        assert collection.name == "test_collection"
-        assert collection.cmetadata == {"foo": "bar"}