From ddf4e7c633eb7833ad3278507b5e7cbba280f7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=90=E7=B2=92=20Yanli?= Date: Tue, 16 Jan 2024 03:41:59 +0800 Subject: [PATCH] community[minor]: Update pgvecto_rs to use its high level sdk (#15574) - **Description:** Update pgvecto_rs to use its high-level SDK. - **Issue:** fixes #15173 --- .../vectorstores/pgvecto_rs.ipynb | 59 +++++++-- .../vectorstores/pgvecto_rs.py | 123 ++++++++---------- 2 files changed, 102 insertions(+), 80 deletions(-) diff --git a/docs/docs/integrations/vectorstores/pgvecto_rs.ipynb b/docs/docs/integrations/vectorstores/pgvecto_rs.ipynb index e216d530fc..e72aefe9ec 100644 --- a/docs/docs/integrations/vectorstores/pgvecto_rs.ipynb +++ b/docs/docs/integrations/vectorstores/pgvecto_rs.ipynb @@ -6,7 +6,7 @@ "source": [ "# PGVecto.rs\n", "\n", - "This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs)). You need to install SQLAlchemy >= 2 manually." + "This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs))."
] }, { @@ -15,10 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "## Loading Environment Variables\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()" + "%pip install \"pgvecto_rs[sdk]\"" ] }, { @@ -32,8 +29,8 @@ "from langchain.docstore.document import Document\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain_community.document_loaders import TextLoader\n", - "from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs\n", - "from langchain_openai import OpenAIEmbeddings" + "from langchain_community.embeddings.fake import FakeEmbeddings\n", + "from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs" ] }, { @@ -42,12 +39,12 @@ "metadata": {}, "outputs": [], "source": [ - "loader = TextLoader(\"../../../state_of_the_union.txt\")\n", + "loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n", "documents = loader.load()\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "docs = text_splitter.split_documents(documents)\n", "\n", - "embeddings = OpenAIEmbeddings()" + "embeddings = FakeEmbeddings(size=3)" ] }, { @@ -176,7 +173,42 @@ "outputs": [], "source": [ "query = \"What did the president say about Ketanji Brown Jackson\"\n", - "docs: List[Document] = db1.similarity_search(query, k=4)" + "docs: List[Document] = db1.similarity_search(query, k=4)\n", + "for doc in docs:\n", + " print(doc.page_content)\n", + " print(\"======================\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Similarity Search with Filter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pgvecto_rs.sdk.filters import meta_contains\n", + "\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs: List[Document] = db1.similarity_search(\n", + " query, k=4, filter=meta_contains({\"source\": \"../../modules/state_of_the_union.txt\"})\n", + ")\n", + 
"\n", + "for doc in docs:\n", + " print(doc.page_content)\n", + " print(\"======================\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Or:" ] }, { @@ -185,6 +217,11 @@ "metadata": {}, "outputs": [], "source": [ + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs: List[Document] = db1.similarity_search(\n", + " query, k=4, filter={\"source\": \"../../modules/state_of_the_union.txt\"}\n", + ")\n", + "\n", "for doc in docs:\n", " print(doc.page_content)\n", " print(\"======================\")" @@ -207,7 +244,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.6" } }, "nbformat": 4, diff --git a/libs/community/langchain_community/vectorstores/pgvecto_rs.py b/libs/community/langchain_community/vectorstores/pgvecto_rs.py index 6c4a887a2d..18d66c80a8 100644 --- a/libs/community/langchain_community/vectorstores/pgvecto_rs.py +++ b/libs/community/langchain_community/vectorstores/pgvecto_rs.py @@ -1,32 +1,17 @@ from __future__ import annotations import uuid -from typing import Any, Iterable, List, Literal, Optional, Tuple, Type +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union -import numpy as np -import sqlalchemy from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from sqlalchemy import insert, select -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy.orm.session import Session - - -class _ORMBase(DeclarativeBase): - __tablename__: str - id: Mapped[uuid.UUID] - text: Mapped[str] - meta: Mapped[dict] - embedding: Mapped[np.ndarray] class PGVecto_rs(VectorStore): """VectorStore backed by pgvecto_rs.""" - _engine: sqlalchemy.engine.Engine - _table: Type[_ORMBase] + _store = None _embedding: Embeddings def __init__( @@ -45,28 +30,22 @@ 
class PGVecto_rs(VectorStore): db_url: Database URL. collection_name: Name of the collection. new_table: Whether to create a new table or connect to an existing one. - Defaults to False. + If true, the table will be dropped if exists, then recreated. + Defaults to False. """ try: - from pgvecto_rs.sqlalchemy import Vector + from pgvecto_rs.sdk import PGVectoRs except ImportError as e: raise ImportError( - "Unable to import pgvector_rs, please install with " - "`pip install pgvector_rs`." + "Unable to import pgvector_rs.sdk , please install with " + '`pip install "pgvector_rs[sdk]"`.' ) from e - - class _Table(_ORMBase): - __tablename__ = f"collection_{collection_name}" - id: Mapped[uuid.UUID] = mapped_column( - postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) - text: Mapped[str] = mapped_column(sqlalchemy.String) - meta: Mapped[dict] = mapped_column(postgresql.JSONB) - embedding: Mapped[np.ndarray] = mapped_column(Vector(dimension)) - - self._engine = sqlalchemy.create_engine(db_url) - self._table = _Table - self._table.__table__.create(self._engine, checkfirst=not new_table) # type: ignore + self._store = PGVectoRs( + db_url=db_url, + collection_name=collection_name, + dimension=dimension, + recreate=new_table, + ) self._embedding = embedding # ================ Create interface ================= @@ -90,7 +69,6 @@ class PGVecto_rs(VectorStore): dimension=dimension, db_url=db_url, collection_name=collection_name, - new_table=True, ) _self.add_texts(texts, metadatas, **kwargs) return _self @@ -148,19 +126,15 @@ class PGVecto_rs(VectorStore): List of ids of the added texts. 
""" + from pgvecto_rs.sdk import Record + embeddings = self._embedding.embed_documents(list(texts)) - with Session(self._engine) as _session: - results: List[str] = [] - for text, embedding, metadata in zip( - texts, embeddings, metadatas or [dict()] * len(list(texts)) - ): - t = insert(self._table).values( - text=text, meta=metadata, embedding=embedding - ) - id = _session.execute(t).inserted_primary_key[0] # type: ignore - results.append(str(id)) - _session.commit() - return results + records = [ + Record.from_text(text, embedding, meta) + for text, embedding, meta in zip(texts, embeddings, metadatas or []) + ] + self._store.insert(records) + return [str(record.id) for record in records] def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: """Run more documents through the embeddings and add to the vectorstore. @@ -185,31 +159,41 @@ class PGVecto_rs(VectorStore): distance_func: Literal[ "sqrt_euclid", "neg_dot_prod", "ned_cos" ] = "sqrt_euclid", + filter: Union[None, Dict[str, Any], Any] = None, **kwargs: Any, ) -> List[Tuple[Document, float]]: """Return docs most similar to query vector, with its score.""" - with Session(self._engine) as _session: - real_distance_func = ( - self._table.embedding.squared_euclidean_distance - if distance_func == "sqrt_euclid" - else self._table.embedding.negative_dot_product_distance - if distance_func == "neg_dot_prod" - else self._table.embedding.negative_cosine_distance - if distance_func == "ned_cos" - else None - ) - if real_distance_func is None: - raise ValueError("Invalid distance function") - t = ( - select(self._table, real_distance_func(query_vector).label("score")) - .order_by("score") - .limit(k) # type: ignore + from pgvecto_rs.sdk.filters import meta_contains + + distance_func_map = { + "sqrt_euclid": "<->", + "neg_dot_prod": "<#>", + "ned_cos": "<=>", + } + if filter is None: + real_filter = None + elif isinstance(filter, dict): + real_filter = meta_contains(filter) + else: + real_filter 
= filter + results = self._store.search( + query_vector, + distance_func_map[distance_func], + k, + filter=real_filter, + ) + + return [ + ( + Document( + page_content=res[0].text, + metadata=res[0].meta, + ), + res[1], ) - return [ - (Document(page_content=row[0].text, metadata=row[0].meta), row[1]) - for row in _session.execute(t) - ] + for res in results + ] def similarity_search_by_vector( self, @@ -218,11 +202,12 @@ class PGVecto_rs(VectorStore): distance_func: Literal[ "sqrt_euclid", "neg_dot_prod", "ned_cos" ] = "sqrt_euclid", + filter: Optional[Any] = None, **kwargs: Any, ) -> List[Document]: return [ doc - for doc, score in self.similarity_search_with_score_by_vector( + for doc, _score in self.similarity_search_with_score_by_vector( embedding, k, distance_func, **kwargs ) ] @@ -254,7 +239,7 @@ class PGVecto_rs(VectorStore): query_vector = self._embedding.embed_query(query) return [ doc - for doc, score in self.similarity_search_with_score_by_vector( + for doc, _score in self.similarity_search_with_score_by_vector( query_vector, k, distance_func, **kwargs ) ]