forked from Archives/langchain
Improve AnalyticDB Vector Store implementation without affecting user (#6086)
Hi there: As I implement the AnalyticDB VectorStore use two table to store the document before. It seems just use one table is a better way. So this commit is try to improve AnalyticDB VectorStore implementation without affecting user behavior: **1. Streamline the `post_init `behavior by creating a single table with vector indexing. 2. Update the `add_texts` API for document insertion. 3. Optimize `similarity_search_with_score_by_vector` to retrieve results directly from the table. 4. Implement `_similarity_search_with_relevance_scores`. 5. Add `embedding_dimension` parameter to support different dimension embedding functions.** Users can continue using the API as before. Test cases added before is enough to meet this commit.
This commit is contained in:
parent
cdd1d78bf2
commit
444ca3f669
@ -3,114 +3,28 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy import REAL, Index
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSON, UUID
|
||||
from sqlalchemy import REAL, Column, String, Table, create_engine, insert, text
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSON, TEXT
|
||||
from sqlalchemy.engine import Row
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, relationship
|
||||
from sqlalchemy.sql.expression import func
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
_LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536
|
||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document"
|
||||
|
||||
Base = declarative_base() # type: Any
|
||||
|
||||
|
||||
ADA_TOKEN_COUNT = 1536
|
||||
_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
|
||||
|
||||
|
||||
class BaseModel(Base):
|
||||
__abstract__ = True
|
||||
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
|
||||
class CollectionStore(BaseModel):
|
||||
__tablename__ = "langchain_pg_collection"
|
||||
|
||||
name = sqlalchemy.Column(sqlalchemy.String)
|
||||
cmetadata = sqlalchemy.Column(JSON)
|
||||
|
||||
embeddings = relationship(
|
||||
"EmbeddingStore",
|
||||
back_populates="collection",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]:
|
||||
return session.query(cls).filter(cls.name == name).first() # type: ignore
|
||||
|
||||
@classmethod
|
||||
def get_or_create(
|
||||
cls,
|
||||
session: Session,
|
||||
name: str,
|
||||
cmetadata: Optional[dict] = None,
|
||||
) -> Tuple["CollectionStore", bool]:
|
||||
"""
|
||||
Get or create a collection.
|
||||
Returns [Collection, bool] where the bool is True if the collection was created.
|
||||
"""
|
||||
created = False
|
||||
collection = cls.get_by_name(session, name)
|
||||
if collection:
|
||||
return collection, created
|
||||
|
||||
collection = cls(name=name, cmetadata=cmetadata)
|
||||
session.add(collection)
|
||||
session.commit()
|
||||
created = True
|
||||
return collection, created
|
||||
|
||||
|
||||
class EmbeddingStore(BaseModel):
|
||||
__tablename__ = "langchain_pg_embedding"
|
||||
|
||||
collection_id = sqlalchemy.Column(
|
||||
UUID(as_uuid=True),
|
||||
sqlalchemy.ForeignKey(
|
||||
f"{CollectionStore.__tablename__}.uuid",
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
collection = relationship(CollectionStore, back_populates="embeddings")
|
||||
|
||||
embedding: sqlalchemy.Column = sqlalchemy.Column(ARRAY(REAL))
|
||||
document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
cmetadata = sqlalchemy.Column(JSON, nullable=True)
|
||||
|
||||
# custom_id : any user defined id
|
||||
custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
|
||||
|
||||
# The following line creates an index named 'langchain_pg_embedding_vector_idx'
|
||||
langchain_pg_embedding_vector_idx = Index(
|
||||
"langchain_pg_embedding_vector_idx",
|
||||
embedding,
|
||||
postgresql_using="ann",
|
||||
postgresql_with={
|
||||
"distancemeasure": "L2",
|
||||
"dim": 1536,
|
||||
"pq_segments": 64,
|
||||
"hnsw_m": 100,
|
||||
"pq_centers": 2048,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class QueryResult:
|
||||
EmbeddingStore: EmbeddingStore
|
||||
distance: float
|
||||
|
||||
|
||||
class AnalyticDB(VectorStore):
|
||||
"""
|
||||
VectorStore implementation using AnalyticDB.
|
||||
@ -132,15 +46,15 @@ class AnalyticDB(VectorStore):
|
||||
self,
|
||||
connection_string: str,
|
||||
embedding_function: Embeddings,
|
||||
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
collection_metadata: Optional[dict] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self.connection_string = connection_string
|
||||
self.embedding_function = embedding_function
|
||||
self.embedding_dimension = embedding_dimension
|
||||
self.collection_name = collection_name
|
||||
self.collection_metadata = collection_metadata
|
||||
self.pre_delete_collection = pre_delete_collection
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.__post_init__()
|
||||
@ -151,47 +65,68 @@ class AnalyticDB(VectorStore):
|
||||
"""
|
||||
Initialize the store.
|
||||
"""
|
||||
self._conn = self.connect()
|
||||
self.create_tables_if_not_exists()
|
||||
self.engine = create_engine(self.connection_string)
|
||||
self.create_collection()
|
||||
|
||||
def connect(self) -> sqlalchemy.engine.Connection:
|
||||
engine = sqlalchemy.create_engine(self.connection_string)
|
||||
conn = engine.connect()
|
||||
return conn
|
||||
def create_table_if_not_exists(self) -> None:
|
||||
# Define the dynamic table
|
||||
Table(
|
||||
self.collection_name,
|
||||
Base.metadata,
|
||||
Column("id", TEXT, primary_key=True, default=uuid.uuid4),
|
||||
Column("embedding", ARRAY(REAL)),
|
||||
Column("document", String, nullable=True),
|
||||
Column("metadata", JSON, nullable=True),
|
||||
extend_existing=True,
|
||||
)
|
||||
with self.engine.connect() as conn:
|
||||
# Create the table
|
||||
Base.metadata.create_all(conn)
|
||||
|
||||
def create_tables_if_not_exists(self) -> None:
|
||||
Base.metadata.create_all(self._conn)
|
||||
# Check if the index exists
|
||||
index_name = f"{self.collection_name}_embedding_idx"
|
||||
index_query = text(
|
||||
f"""
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE indexname = '{index_name}';
|
||||
"""
|
||||
)
|
||||
result = conn.execute(index_query).scalar()
|
||||
|
||||
def drop_tables(self) -> None:
|
||||
Base.metadata.drop_all(self._conn)
|
||||
# Create the index if it doesn't exist
|
||||
if not result:
|
||||
index_statement = text(
|
||||
f"""
|
||||
CREATE INDEX {index_name}
|
||||
ON {self.collection_name} USING ann(embedding)
|
||||
WITH (
|
||||
"dim" = {self.embedding_dimension},
|
||||
"hnsw_m" = 100
|
||||
);
|
||||
"""
|
||||
)
|
||||
conn.execute(index_statement)
|
||||
conn.commit()
|
||||
|
||||
def create_collection(self) -> None:
|
||||
if self.pre_delete_collection:
|
||||
self.delete_collection()
|
||||
with Session(self._conn) as session:
|
||||
CollectionStore.get_or_create(
|
||||
session, self.collection_name, cmetadata=self.collection_metadata
|
||||
)
|
||||
self.create_table_if_not_exists()
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
self.logger.debug("Trying to delete collection")
|
||||
with Session(self._conn) as session:
|
||||
collection = self.get_collection(session)
|
||||
if not collection:
|
||||
self.logger.error("Collection not found")
|
||||
return
|
||||
session.delete(collection)
|
||||
session.commit()
|
||||
|
||||
def get_collection(self, session: Session) -> Optional["CollectionStore"]:
|
||||
return CollectionStore.get_by_name(session, self.collection_name)
|
||||
drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};")
|
||||
with self.engine.connect() as conn:
|
||||
conn.execute(drop_statement)
|
||||
conn.commit()
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
batch_size: int = 500,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
@ -212,20 +147,43 @@ class AnalyticDB(VectorStore):
|
||||
if not metadatas:
|
||||
metadatas = [{} for _ in texts]
|
||||
|
||||
with Session(self._conn) as session:
|
||||
collection = self.get_collection(session)
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
||||
embedding_store = EmbeddingStore(
|
||||
embedding=embedding,
|
||||
document=text,
|
||||
cmetadata=metadata,
|
||||
custom_id=id,
|
||||
# Define the table schema
|
||||
chunks_table = Table(
|
||||
self.collection_name,
|
||||
Base.metadata,
|
||||
Column("id", TEXT, primary_key=True),
|
||||
Column("embedding", ARRAY(REAL)),
|
||||
Column("document", String, nullable=True),
|
||||
Column("metadata", JSON, nullable=True),
|
||||
extend_existing=True,
|
||||
)
|
||||
|
||||
chunks_table_data = []
|
||||
with self.engine.connect() as conn:
|
||||
for document, metadata, chunk_id, embedding in zip(
|
||||
texts, metadatas, ids, embeddings
|
||||
):
|
||||
chunks_table_data.append(
|
||||
{
|
||||
"id": chunk_id,
|
||||
"embedding": embedding,
|
||||
"document": document,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
collection.embeddings.append(embedding_store)
|
||||
session.add(embedding_store)
|
||||
session.commit()
|
||||
|
||||
# Execute the batch insert when the batch size is reached
|
||||
if len(chunks_table_data) == batch_size:
|
||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||
# Clear the chunks_table_data list for the next batch
|
||||
chunks_table_data.clear()
|
||||
|
||||
# Insert any remaining records that didn't make up a full batch
|
||||
if chunks_table_data:
|
||||
conn.execute(insert(chunks_table).values(chunks_table_data))
|
||||
|
||||
# Commit the transaction only once after all records have been inserted
|
||||
conn.commit()
|
||||
|
||||
return ids
|
||||
|
||||
@ -275,52 +233,69 @@ class AnalyticDB(VectorStore):
|
||||
)
|
||||
return docs
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
|
||||
Args:
|
||||
query: input text
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of Tuples of (doc, similarity_score)
|
||||
"""
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
with Session(self._conn) as session:
|
||||
collection = self.get_collection(session)
|
||||
if not collection:
|
||||
raise ValueError("Collection not found")
|
||||
|
||||
filter_by = EmbeddingStore.collection_id == collection.uuid
|
||||
|
||||
# Add the filter if provided
|
||||
filter_condition = ""
|
||||
if filter is not None:
|
||||
filter_clauses = []
|
||||
for key, value in filter.items():
|
||||
filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value)
|
||||
filter_clauses.append(filter_by_metadata)
|
||||
conditions = [
|
||||
f"metadata->>{key!r} = {value!r}" for key, value in filter.items()
|
||||
]
|
||||
filter_condition = f"WHERE {' AND '.join(conditions)}"
|
||||
|
||||
filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
|
||||
# Define the base query
|
||||
sql_query = f"""
|
||||
SELECT *, l2_distance(embedding, :embedding) as distance
|
||||
FROM {self.collection_name}
|
||||
{filter_condition}
|
||||
ORDER BY embedding <-> :embedding
|
||||
LIMIT :k
|
||||
"""
|
||||
|
||||
results: List[QueryResult] = (
|
||||
session.query(
|
||||
EmbeddingStore,
|
||||
func.l2_distance(EmbeddingStore.embedding, embedding).label("distance"),
|
||||
)
|
||||
.filter(filter_by)
|
||||
.order_by(EmbeddingStore.embedding.op("<->")(embedding))
|
||||
.join(
|
||||
CollectionStore,
|
||||
EmbeddingStore.collection_id == CollectionStore.uuid,
|
||||
)
|
||||
.limit(k)
|
||||
.all()
|
||||
)
|
||||
docs = [
|
||||
# Set up the query parameters
|
||||
params = {"embedding": embedding, "k": k}
|
||||
|
||||
# Execute the query and fetch the results
|
||||
with self.engine.connect() as conn:
|
||||
results: Sequence[Row] = conn.execute(text(sql_query), params).fetchall()
|
||||
|
||||
documents_with_scores = [
|
||||
(
|
||||
Document(
|
||||
page_content=result.EmbeddingStore.document,
|
||||
metadata=result.EmbeddingStore.cmetadata,
|
||||
page_content=result.document,
|
||||
metadata=result.metadata,
|
||||
),
|
||||
result.distance if self.embedding_function is not None else None,
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
return docs
|
||||
return documents_with_scores
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
@ -346,10 +321,11 @@ class AnalyticDB(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
cls: Type[AnalyticDB],
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
ids: Optional[List[str]] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
@ -368,6 +344,7 @@ class AnalyticDB(VectorStore):
|
||||
connection_string=connection_string,
|
||||
collection_name=collection_name,
|
||||
embedding_function=embedding,
|
||||
embedding_dimension=embedding_dimension,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
)
|
||||
|
||||
@ -379,7 +356,7 @@ class AnalyticDB(VectorStore):
|
||||
connection_string: str = get_from_dict_or_env(
|
||||
data=kwargs,
|
||||
key="connection_string",
|
||||
env_key="PGVECTOR_CONNECTION_STRING",
|
||||
env_key="PG_CONNECTION_STRING",
|
||||
)
|
||||
|
||||
if not connection_string:
|
||||
@ -393,9 +370,10 @@ class AnalyticDB(VectorStore):
|
||||
|
||||
@classmethod
|
||||
def from_documents(
|
||||
cls,
|
||||
cls: Type[AnalyticDB],
|
||||
documents: List[Document],
|
||||
embedding: Embeddings,
|
||||
embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
ids: Optional[List[str]] = None,
|
||||
pre_delete_collection: bool = False,
|
||||
@ -418,6 +396,7 @@ class AnalyticDB(VectorStore):
|
||||
texts=texts,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
embedding=embedding,
|
||||
embedding_dimension=embedding_dimension,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
collection_name=collection_name,
|
||||
|
@ -2,8 +2,6 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.analyticdb import AnalyticDB
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
@ -11,7 +9,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
CONNECTION_STRING = AnalyticDB.connection_string_from_db_params(
|
||||
driver=os.environ.get("PG_DRIVER", "psycopg2cffi"),
|
||||
host=os.environ.get("PG_HOST", "localhost"),
|
||||
port=int(os.environ.get("PG_HOST", "5432")),
|
||||
port=int(os.environ.get("PG_PORT", "5432")),
|
||||
database=os.environ.get("PG_DATABASE", "postgres"),
|
||||
user=os.environ.get("PG_USER", "postgres"),
|
||||
password=os.environ.get("PG_PASSWORD", "postgres"),
|
||||
@ -128,21 +126,3 @@ def test_analyticdb_with_filter_no_match() -> None:
|
||||
)
|
||||
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"})
|
||||
assert output == []
|
||||
|
||||
|
||||
def test_analyticdb_collection_with_metadata() -> None:
|
||||
"""Test end to end collection construction"""
|
||||
pgvector = AnalyticDB(
|
||||
collection_name="test_collection",
|
||||
collection_metadata={"foo": "bar"},
|
||||
embedding_function=FakeEmbeddingsWithAdaDimension(),
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
session = Session(pgvector.connect())
|
||||
collection = pgvector.get_collection(session)
|
||||
if collection is None:
|
||||
assert False, "Expected a CollectionStore object but received None"
|
||||
else:
|
||||
assert collection.name == "test_collection"
|
||||
assert collection.cmetadata == {"foo": "bar"}
|
||||
|
Loading…
Reference in New Issue
Block a user