diff --git a/langchain/vectorstores/analyticdb.py b/langchain/vectorstores/analyticdb.py index d8986aeffa..86917b628f 100644 --- a/langchain/vectorstores/analyticdb.py +++ b/langchain/vectorstores/analyticdb.py @@ -49,6 +49,7 @@ class AnalyticDB(VectorStore): collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, pre_delete_collection: bool = False, logger: Optional[logging.Logger] = None, + engine_args: Optional[dict] = None, ) -> None: self.connection_string = connection_string self.embedding_function = embedding_function @@ -56,15 +57,26 @@ class AnalyticDB(VectorStore): self.collection_name = collection_name self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) - self.__post_init__() + self.__post_init__(engine_args) def __post_init__( self, + engine_args: Optional[dict] = None, ) -> None: """ Initialize the store. """ - self.engine = create_engine(self.connection_string) + + _engine_args = engine_args or {} + + if ( + "pool_recycle" not in _engine_args + ): # Check if pool_recycle is not in _engine_args + _engine_args[ + "pool_recycle" + ] = 3600 # Set pool_recycle to 3600s if not present + + self.engine = create_engine(self.connection_string, **_engine_args) self.create_collection() def create_table_if_not_exists(self) -> None: @@ -324,6 +336,36 @@ class AnalyticDB(VectorStore): ) return [doc for doc, _ in docs_and_scores] + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: + """Delete by vector IDs. + + Args: + ids: List of ids to delete. + """ + if ids is None: + raise ValueError("No ids provided to delete.") + + # Define the table schema + chunks_table = Table( + self.collection_name, + Base.metadata, + Column("id", TEXT, primary_key=True), + Column("embedding", ARRAY(REAL)), + Column("document", String, nullable=True), + Column("metadata", JSON, nullable=True), + extend_existing=True, + ) + + try: + with self.engine.connect() as conn: + with conn.begin(): + delete_condition = chunks_table.c.id.in_(ids) + conn.execute(chunks_table.delete().where(delete_condition)) + return True + except Exception as e: + print("Delete operation failed:", str(e)) + return False + @classmethod def from_texts( cls: Type[AnalyticDB], @@ -334,6 +376,7 @@ class AnalyticDB(VectorStore): collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, + engine_args: Optional[dict] = None, **kwargs: Any, ) -> AnalyticDB: """ @@ -351,6 +394,7 @@ class AnalyticDB(VectorStore): embedding_function=embedding, embedding_dimension=embedding_dimension, pre_delete_collection=pre_delete_collection, + engine_args=engine_args, ) store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs) @@ -382,6 +426,7 @@ class AnalyticDB(VectorStore): collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, ids: Optional[List[str]] = None, pre_delete_collection: bool = False, + engine_args: Optional[dict] = None, **kwargs: Any, ) -> AnalyticDB: """ @@ -405,6 +450,7 @@ class AnalyticDB(VectorStore): metadatas=metadatas, ids=ids, collection_name=collection_name, + engine_args=engine_args, **kwargs, ) diff --git a/tests/integration_tests/vectorstores/test_analyticdb.py b/tests/integration_tests/vectorstores/test_analyticdb.py index 1149b2259b..c8ce1c9bb0 100644 --- a/tests/integration_tests/vectorstores/test_analyticdb.py +++ b/tests/integration_tests/vectorstores/test_analyticdb.py @@ -47,6 +47,22 @@ def test_analyticdb() -> None: assert output == [Document(page_content="foo")] +def test_analyticdb_with_engine_args() -> None: + engine_args = {"pool_recycle": 3600, "pool_size": 50} + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = AnalyticDB.from_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + engine_args=engine_args, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + def test_analyticdb_with_metadatas() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -126,3 +142,25 @@ def test_analyticdb_with_filter_no_match() -> None: ) output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "5"}) assert output == [] + + +def test_analyticdb_delete() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + ids = ["fooid", "barid", "bazid"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = AnalyticDB.from_texts( + texts=texts, + collection_name="test_collection_delete", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection_string=CONNECTION_STRING, + ids=ids, + pre_delete_collection=True, + ) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) + print(output) + assert output == [(Document(page_content="baz", metadata={"page": "2"}), 4.0)] + docsearch.delete(ids=ids) + output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"}) + assert output == []