diff --git a/langchain/vectorstores/analyticdb.py b/langchain/vectorstores/analyticdb.py index 47460ad2..b93172a9 100644 --- a/langchain/vectorstores/analyticdb.py +++ b/langchain/vectorstores/analyticdb.py @@ -3,112 +3,26 @@ from __future__ import annotations import logging import uuid -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type -import sqlalchemy -from sqlalchemy import REAL, Index -from sqlalchemy.dialects.postgresql import ARRAY, JSON, UUID +from sqlalchemy import REAL, Column, String, Table, create_engine, insert, text +from sqlalchemy.dialects.postgresql import ARRAY, JSON, TEXT +from sqlalchemy.engine import Row try: from sqlalchemy.orm import declarative_base except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, relationship -from sqlalchemy.sql.expression import func from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore -Base = declarative_base() # type: Any - - -ADA_TOKEN_COUNT = 1536 -_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" - - -class BaseModel(Base): - __abstract__ = True - uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - - -class CollectionStore(BaseModel): - __tablename__ = "langchain_pg_collection" +_LANGCHAIN_DEFAULT_EMBEDDING_DIM = 1536 +_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain_document" - name = sqlalchemy.Column(sqlalchemy.String) - cmetadata = sqlalchemy.Column(JSON) - - embeddings = relationship( - "EmbeddingStore", - back_populates="collection", - passive_deletes=True, - ) - - @classmethod - def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]: - return session.query(cls).filter(cls.name == name).first() # type: ignore - - @classmethod - def get_or_create( - cls, - session: Session, - name: str, - cmetadata: Optional[dict] = None, - ) -> Tuple["CollectionStore", bool]: - """ - Get or create a collection. - Returns [Collection, bool] where the bool is True if the collection was created. 
- """ - created = False - collection = cls.get_by_name(session, name) - if collection: - return collection, created - - collection = cls(name=name, cmetadata=cmetadata) - session.add(collection) - session.commit() - created = True - return collection, created - - -class EmbeddingStore(BaseModel): - __tablename__ = "langchain_pg_embedding" - - collection_id = sqlalchemy.Column( - UUID(as_uuid=True), - sqlalchemy.ForeignKey( - f"{CollectionStore.__tablename__}.uuid", - ondelete="CASCADE", - ), - ) - collection = relationship(CollectionStore, back_populates="embeddings") - - embedding: sqlalchemy.Column = sqlalchemy.Column(ARRAY(REAL)) - document = sqlalchemy.Column(sqlalchemy.String, nullable=True) - cmetadata = sqlalchemy.Column(JSON, nullable=True) - - # custom_id : any user defined id - custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) - - # The following line creates an index named 'langchain_pg_embedding_vector_idx' - langchain_pg_embedding_vector_idx = Index( - "langchain_pg_embedding_vector_idx", - embedding, - postgresql_using="ann", - postgresql_with={ - "distancemeasure": "L2", - "dim": 1536, - "pq_segments": 64, - "hnsw_m": 100, - "pq_centers": 2048, - }, - ) - - -class QueryResult: - EmbeddingStore: EmbeddingStore - distance: float +Base = declarative_base() # type: Any class AnalyticDB(VectorStore): @@ -132,15 +46,15 @@ class AnalyticDB(VectorStore): self, connection_string: str, embedding_function: Embeddings, + embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, - collection_metadata: Optional[dict] = None, pre_delete_collection: bool = False, logger: Optional[logging.Logger] = None, ) -> None: self.connection_string = connection_string self.embedding_function = embedding_function + self.embedding_dimension = embedding_dimension self.collection_name = collection_name - self.collection_metadata = collection_metadata self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) self.__post_init__() @@ -151,47 +65,68 @@ class AnalyticDB(VectorStore): """ Initialize the store. 
""" - self._conn = self.connect() - self.create_tables_if_not_exists() + self.engine = create_engine(self.connection_string) self.create_collection() - def connect(self) -> sqlalchemy.engine.Connection: - engine = sqlalchemy.create_engine(self.connection_string) - conn = engine.connect() - return conn - - def create_tables_if_not_exists(self) -> None: - Base.metadata.create_all(self._conn) - - def drop_tables(self) -> None: - Base.metadata.drop_all(self._conn) + def create_table_if_not_exists(self) -> None: + # Define the dynamic table + Table( + self.collection_name, + Base.metadata, + Column("id", TEXT, primary_key=True, default=uuid.uuid4), + Column("embedding", ARRAY(REAL)), + Column("document", String, nullable=True), + Column("metadata", JSON, nullable=True), + extend_existing=True, + ) + with self.engine.connect() as conn: + # Create the table + Base.metadata.create_all(conn) + + # Check if the index exists + index_name = f"{self.collection_name}_embedding_idx" + index_query = text( + f""" + SELECT 1 + FROM pg_indexes + WHERE indexname = '{index_name}'; + """ + ) + result = conn.execute(index_query).scalar() + + # Create the index if it doesn't exist + if not result: + index_statement = text( + f""" + CREATE INDEX {index_name} + ON {self.collection_name} USING ann(embedding) + WITH ( + "dim" = {self.embedding_dimension}, + "hnsw_m" = 100 + ); + """ + ) + conn.execute(index_statement) + conn.commit() def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() - with Session(self._conn) as session: - CollectionStore.get_or_create( - session, self.collection_name, cmetadata=self.collection_metadata - ) + self.create_table_if_not_exists() def delete_collection(self) -> None: self.logger.debug("Trying to delete collection") - with Session(self._conn) as session: - collection = self.get_collection(session) - if not collection: - self.logger.error("Collection not found") - return - session.delete(collection) - session.commit() - - def get_collection(self, session: Session) -> Optional["CollectionStore"]: - return CollectionStore.get_by_name(session, self.collection_name) + drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};") + with self.engine.connect() as conn: + conn.execute(drop_statement) + conn.commit() def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, ids: Optional[List[str]] = None, + batch_size: int = 500, **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore. 
@@ -212,20 +147,43 @@ class AnalyticDB(VectorStore): if not metadatas: metadatas = [{} for _ in texts] - with Session(self._conn) as session: - collection = self.get_collection(session) - if not collection: - raise ValueError("Collection not found") - for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): - embedding_store = EmbeddingStore( - embedding=embedding, - document=text, - cmetadata=metadata, - custom_id=id, + # Define the table schema + chunks_table = Table( + self.collection_name, + Base.metadata, + Column("id", TEXT, primary_key=True), + Column("embedding", ARRAY(REAL)), + Column("document", String, nullable=True), + Column("metadata", JSON, nullable=True), + extend_existing=True, + ) + + chunks_table_data = [] + with self.engine.connect() as conn: + for document, metadata, chunk_id, embedding in zip( + texts, metadatas, ids, embeddings + ): + chunks_table_data.append( + { + "id": chunk_id, + "embedding": embedding, + "document": document, + "metadata": metadata, + } ) - collection.embeddings.append(embedding_store) - session.add(embedding_store) - session.commit() + + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == batch_size: + conn.execute(insert(chunks_table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: + conn.execute(insert(chunks_table).values(chunks_table_data)) + + # Commit the transaction only once after all records have been inserted + conn.commit() return ids @@ -275,52 +233,69 @@ class AnalyticDB(VectorStore): ) return docs + def _similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs and relevance scores in the range [0, 1]. + + 0 is dissimilar, 1 is most similar. + + Args: + query: input text + k: Number of Documents to return. Defaults to 4. + **kwargs: kwargs to be passed to similarity search. 
Should include:
+                score_threshold: Optional, a floating point value between 0 and 1 to
+                    filter the resulting set of retrieved docs
+
+        Returns:
+            List of Tuples of (doc, similarity_score)
+        """
+        return self.similarity_search_with_score(query, k, **kwargs)
+
     def similarity_search_with_score_by_vector(
         self,
         embedding: List[float],
         k: int = 4,
         filter: Optional[dict] = None,
     ) -> List[Tuple[Document, float]]:
-        with Session(self._conn) as session:
-            collection = self.get_collection(session)
-            if not collection:
-                raise ValueError("Collection not found")
-
-            filter_by = EmbeddingStore.collection_id == collection.uuid
-
+        # Add the filter if provided
+        filter_condition = ""
         if filter is not None:
-            filter_clauses = []
-            for key, value in filter.items():
-                filter_by_metadata = EmbeddingStore.cmetadata[key].astext == str(value)
-                filter_clauses.append(filter_by_metadata)
+            conditions = [
+                f"metadata->>{key!r} = {value!r}" for key, value in filter.items()
+            ]
+            filter_condition = f"WHERE {' AND '.join(conditions)}"
+
+        # Define the base query
+        sql_query = f"""
+            SELECT *, l2_distance(embedding, :embedding) as distance
+            FROM {self.collection_name}
+            {filter_condition}
+            ORDER BY embedding <-> :embedding
+            LIMIT :k
+        """

-            filter_by = sqlalchemy.and_(filter_by, *filter_clauses)
+        # Set up the query parameters
+        params = {"embedding": embedding, "k": k}

-        results: List[QueryResult] = (
-            session.query(
-                EmbeddingStore,
-                func.l2_distance(EmbeddingStore.embedding, embedding).label("distance"),
-            )
-            .filter(filter_by)
-            .order_by(EmbeddingStore.embedding.op("<->")(embedding))
-            .join(
-                CollectionStore,
-                EmbeddingStore.collection_id == CollectionStore.uuid,
-            )
-            .limit(k)
-            .all()
-        )
-        docs = [
+        # Execute the query and fetch the results
+        with self.engine.connect() as conn:
+            results: Sequence[Row] = conn.execute(text(sql_query), params).fetchall()
+
+        documents_with_scores = [
             (
                 Document(
-                    page_content=result.EmbeddingStore.document,
-                    metadata=result.EmbeddingStore.cmetadata,
+                    page_content=result.document,
+                    metadata=result.metadata,
                 ),
                 result.distance if self.embedding_function is not None else None,
             )
             for result in results
         ]
-        return docs
+        return documents_with_scores

     def similarity_search_by_vector(
         self,
@@ -346,10 +321,11 @@ class AnalyticDB(VectorStore):

     @classmethod
     def from_texts(
-        cls,
+        cls: Type[AnalyticDB],
         texts: List[str],
         embedding: Embeddings,
         metadatas: Optional[List[dict]] = None,
+        embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
@@ -368,6 +344,7 @@ class AnalyticDB(VectorStore):
             connection_string=connection_string,
             collection_name=collection_name,
             embedding_function=embedding,
+            embedding_dimension=embedding_dimension,
             pre_delete_collection=pre_delete_collection,
         )

@@ -379,7 +356,7 @@ class AnalyticDB(VectorStore):
         connection_string: str = get_from_dict_or_env(
             data=kwargs,
             key="connection_string",
-            env_key="PGVECTOR_CONNECTION_STRING",
+            env_key="PG_CONNECTION_STRING",
         )

         if not connection_string:
@@ -393,9 +370,10 @@ class AnalyticDB(VectorStore):

     @classmethod
     def from_documents(
-        cls,
+        cls: Type[AnalyticDB],
         documents: List[Document],
         embedding: Embeddings,
+        embedding_dimension: int = _LANGCHAIN_DEFAULT_EMBEDDING_DIM,
         collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
         ids: Optional[List[str]] = None,
         pre_delete_collection: bool = False,
@@ -418,6 +396,7 @@ class AnalyticDB(VectorStore):
             texts=texts,
pre_delete_collection=pre_delete_collection, embedding=embedding, + embedding_dimension=embedding_dimension, metadatas=metadatas, ids=ids, collection_name=collection_name, diff --git a/tests/integration_tests/vectorstores/test_analyticdb.py b/tests/integration_tests/vectorstores/test_analyticdb.py index d3bbe0e6..1149b225 100644 --- a/tests/integration_tests/vectorstores/test_analyticdb.py +++ b/tests/integration_tests/vectorstores/test_analyticdb.py @@ -2,8 +2,6 @@ import os from typing import List -from sqlalchemy.orm import Session - from langchain.docstore.document import Document from langchain.vectorstores.analyticdb import AnalyticDB from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings @@ -11,7 +9,7 @@ from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings CONNECTION_STRING = AnalyticDB.connection_string_from_db_params( driver=os.environ.get("PG_DRIVER", "psycopg2cffi"), host=os.environ.get("PG_HOST", "localhost"), - port=int(os.environ.get("PG_HOST", "5432")), + port=int(os.environ.get("PG_PORT", "5432")), database=os.environ.get("PG_DATABASE", "postgres"), user=os.environ.get("PG_USER", "postgres"), password=os.environ.get("PG_PASSWORD", "postgres"), @@ -128,21 +126,3 @@ def test_analyticdb_with_filter_no_match() -> None: ) output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] - - -def test_analyticdb_collection_with_metadata() -> None: - """Test end to end collection construction""" - pgvector = AnalyticDB( - collection_name="test_collection", - collection_metadata={"foo": "bar"}, - embedding_function=FakeEmbeddingsWithAdaDimension(), - connection_string=CONNECTION_STRING, - pre_delete_collection=True, - ) - session = Session(pgvector.connect()) - collection = pgvector.get_collection(session) - if collection is None: - assert False, "Expected a CollectionStore object but received None" - else: - assert collection.name == "test_collection" - assert collection.cmetadata == {"foo": "bar"}
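
For reviewers, here is a minimal end-to-end sketch of the refactored interface. It is not part of the change itself: it assumes a reachable AnalyticDB (or PostgreSQL-compatible) instance, borrows the test suite's `FakeEmbeddings` as a stand-in for a real embedding model, and the collection name and the 10-dimensional embedding assumption below are illustrative only.

```python
import os

from langchain.vectorstores.analyticdb import AnalyticDB
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings

# Same connection-string construction the integration tests use.
CONNECTION_STRING = AnalyticDB.connection_string_from_db_params(
    driver=os.environ.get("PG_DRIVER", "psycopg2cffi"),
    host=os.environ.get("PG_HOST", "localhost"),
    port=int(os.environ.get("PG_PORT", "5432")),
    database=os.environ.get("PG_DATABASE", "postgres"),
    user=os.environ.get("PG_USER", "postgres"),
    password=os.environ.get("PG_PASSWORD", "postgres"),
)

# Each collection is now a flat table created on first use, together with an
# ANN index whose "dim" comes from embedding_dimension. It must match the
# vectors the embedding model emits (assumed 10-dimensional for the fake
# model here; the 1536 default targets OpenAI ada-002).
store = AnalyticDB.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=FakeEmbeddings(),
    embedding_dimension=10,
    collection_name="langchain_document_demo",
    connection_string=CONNECTION_STRING,
    pre_delete_collection=True,
)

# add_texts inserts rows in batches of `batch_size` (500 by default) and
# commits once after all batches are written.
store.add_texts(["qux"], metadatas=[{"page": "0"}])

# Scores are raw L2 distances computed server-side: smaller means closer.
for doc, score in store.similarity_search_with_score("foo", k=2):
    print(doc.page_content, score)
```

Compared with the previous shared `langchain_pg_collection` / `langchain_pg_embedding` schema, a collection now maps one-to-one onto its own table, which is why `pre_delete_collection` reduces to a plain `DROP TABLE IF EXISTS` and no collection-metadata bookkeeping remains.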