From ec8247ec5979ff5f8f31f88a797a3e8dfa48b915 Mon Sep 17 00:00:00 2001 From: Richy Wang Date: Mon, 26 Jun 2023 20:35:25 +0800 Subject: [PATCH] Fixed bug in AnalyticDB Vector Store caused by upgrade SQLAlchemy version (#6736) --- langchain/vectorstores/analyticdb.py | 104 +++++++++++++-------------- 1 file changed, 51 insertions(+), 53 deletions(-) diff --git a/langchain/vectorstores/analyticdb.py b/langchain/vectorstores/analyticdb.py index 5d422a3beb..385d666f1e 100644 --- a/langchain/vectorstores/analyticdb.py +++ b/langchain/vectorstores/analyticdb.py @@ -80,34 +80,34 @@ class AnalyticDB(VectorStore): extend_existing=True, ) with self.engine.connect() as conn: - # Create the table - Base.metadata.create_all(conn) + with conn.begin(): + # Create the table + Base.metadata.create_all(conn) - # Check if the index exists - index_name = f"{self.collection_name}_embedding_idx" - index_query = text( - f""" - SELECT 1 - FROM pg_indexes - WHERE indexname = '{index_name}'; - """ - ) - result = conn.execute(index_query).scalar() - - # Create the index if it doesn't exist - if not result: - index_statement = text( + # Check if the index exists + index_name = f"{self.collection_name}_embedding_idx" + index_query = text( f""" - CREATE INDEX {index_name} - ON {self.collection_name} USING ann(embedding) - WITH ( - "dim" = {self.embedding_dimension}, - "hnsw_m" = 100 - ); + SELECT 1 + FROM pg_indexes + WHERE indexname = '{index_name}'; """ ) - conn.execute(index_statement) - conn.commit() + result = conn.execute(index_query).scalar() + + # Create the index if it doesn't exist + if not result: + index_statement = text( + f""" + CREATE INDEX {index_name} + ON {self.collection_name} USING ann(embedding) + WITH ( + "dim" = {self.embedding_dimension}, + "hnsw_m" = 100 + ); + """ + ) + conn.execute(index_statement) def create_collection(self) -> None: if self.pre_delete_collection: @@ -118,8 +118,8 @@ class AnalyticDB(VectorStore): self.logger.debug("Trying to delete collection") drop_statement = text(f"DROP TABLE IF EXISTS {self.collection_name};") with self.engine.connect() as conn: - conn.execute(drop_statement) - conn.commit() + with conn.begin(): + conn.execute(drop_statement) def add_texts( self, @@ -160,30 +160,28 @@ class AnalyticDB(VectorStore): chunks_table_data = [] with self.engine.connect() as conn: - for document, metadata, chunk_id, embedding in zip( - texts, metadatas, ids, embeddings - ): - chunks_table_data.append( - { - "id": chunk_id, - "embedding": embedding, - "document": document, - "metadata": metadata, - } - ) + with conn.begin(): + for document, metadata, chunk_id, embedding in zip( + texts, metadatas, ids, embeddings + ): + chunks_table_data.append( + { + "id": chunk_id, + "embedding": embedding, + "document": document, + "metadata": metadata, + } + ) - # Execute the batch insert when the batch size is reached - if len(chunks_table_data) == batch_size: + # Execute the batch insert when the batch size is reached + if len(chunks_table_data) == batch_size: + conn.execute(insert(chunks_table).values(chunks_table_data)) + # Clear the chunks_table_data list for the next batch + chunks_table_data.clear() + + # Insert any remaining records that didn't make up a full batch + if chunks_table_data: conn.execute(insert(chunks_table).values(chunks_table_data)) - # Clear the chunks_table_data list for the next batch - chunks_table_data.clear() - - # Insert any remaining records that didn't make up a full batch - if chunks_table_data: - conn.execute(insert(chunks_table).values(chunks_table_data)) - - # Commit the transaction only once after all records have been inserted - conn.commit() return ids @@ -333,9 +331,9 @@ class AnalyticDB(VectorStore): ) -> AnalyticDB: """ Return VectorStore initialized from texts and embeddings. - Postgres connection string is required + Postgres Connection string is required Either pass it as a parameter - or set the PGVECTOR_CONNECTION_STRING environment variable. + or set the PG_CONNECTION_STRING environment variable. """ connection_string = cls.get_connection_string(kwargs) @@ -363,7 +361,7 @@ class AnalyticDB(VectorStore): raise ValueError( "Postgres connection string is required" "Either pass it as a parameter" - "or set the PGVECTOR_CONNECTION_STRING environment variable." + "or set the PG_CONNECTION_STRING environment variable." ) return connection_string @@ -381,9 +379,9 @@ class AnalyticDB(VectorStore): ) -> AnalyticDB: """ Return VectorStore initialized from documents and embeddings. - Postgres connection string is required + Postgres Connection string is required Either pass it as a parameter - or set the PGVECTOR_CONNECTION_STRING environment variable. + or set the PG_CONNECTION_STRING environment variable. """ texts = [d.page_content for d in documents]