community[minor]: Update pgvecto_rs to use its high level sdk (#15574)

- **Description:** Update pgvecto_rs to use its high level sdk, 
  - **Issue:** fix #15173
pull/16049/head
盐粒 Yanli 6 months ago committed by GitHub
parent ce21392a21
commit ddf4e7c633
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,7 +6,7 @@
"source": [ "source": [
"# PGVecto.rs\n", "# PGVecto.rs\n",
"\n", "\n",
"This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs)). You need to install SQLAlchemy >= 2 manually." "This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs))."
] ]
}, },
{ {
@ -15,10 +15,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"## Loading Environment Variables\n", "%pip install \"pgvecto_rs[sdk]\""
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()"
] ]
}, },
{ {
@ -32,8 +29,8 @@
"from langchain.docstore.document import Document\n", "from langchain.docstore.document import Document\n",
"from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain_community.document_loaders import TextLoader\n", "from langchain_community.document_loaders import TextLoader\n",
"from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs\n", "from langchain_community.embeddings.fake import FakeEmbeddings\n",
"from langchain_openai import OpenAIEmbeddings" "from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs"
] ]
}, },
{ {
@ -42,12 +39,12 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n", "loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n", "documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n", "docs = text_splitter.split_documents(documents)\n",
"\n", "\n",
"embeddings = OpenAIEmbeddings()" "embeddings = FakeEmbeddings(size=3)"
] ]
}, },
{ {
@ -176,7 +173,42 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs: List[Document] = db1.similarity_search(query, k=4)" "docs: List[Document] = db1.similarity_search(query, k=4)\n",
"for doc in docs:\n",
" print(doc.page_content)\n",
" print(\"======================\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Similarity Search with Filter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pgvecto_rs.sdk.filters import meta_contains\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs: List[Document] = db1.similarity_search(\n",
" query, k=4, filter=meta_contains({\"source\": \"../../modules/state_of_the_union.txt\"})\n",
")\n",
"\n",
"for doc in docs:\n",
" print(doc.page_content)\n",
" print(\"======================\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or:"
] ]
}, },
{ {
@ -185,6 +217,11 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs: List[Document] = db1.similarity_search(\n",
" query, k=4, filter={\"source\": \"../../modules/state_of_the_union.txt\"}\n",
")\n",
"\n",
"for doc in docs:\n", "for doc in docs:\n",
" print(doc.page_content)\n", " print(doc.page_content)\n",
" print(\"======================\")" " print(\"======================\")"
@ -207,7 +244,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.3" "version": "3.11.6"
} }
}, },
"nbformat": 4, "nbformat": 4,

@ -1,32 +1,17 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from typing import Any, Iterable, List, Literal, Optional, Tuple, Type from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
import numpy as np
import sqlalchemy
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from sqlalchemy import insert, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.orm.session import Session
class _ORMBase(DeclarativeBase):
__tablename__: str
id: Mapped[uuid.UUID]
text: Mapped[str]
meta: Mapped[dict]
embedding: Mapped[np.ndarray]
class PGVecto_rs(VectorStore): class PGVecto_rs(VectorStore):
"""VectorStore backed by pgvecto_rs.""" """VectorStore backed by pgvecto_rs."""
_engine: sqlalchemy.engine.Engine _store = None
_table: Type[_ORMBase]
_embedding: Embeddings _embedding: Embeddings
def __init__( def __init__(
@ -45,28 +30,22 @@ class PGVecto_rs(VectorStore):
db_url: Database URL. db_url: Database URL.
collection_name: Name of the collection. collection_name: Name of the collection.
new_table: Whether to create a new table or connect to an existing one. new_table: Whether to create a new table or connect to an existing one.
Defaults to False. If true, the table will be dropped if exists, then recreated.
Defaults to False.
""" """
try: try:
from pgvecto_rs.sqlalchemy import Vector from pgvecto_rs.sdk import PGVectoRs
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Unable to import pgvector_rs, please install with " "Unable to import pgvector_rs.sdk , please install with "
"`pip install pgvector_rs`." '`pip install "pgvector_rs[sdk]"`.'
) from e ) from e
self._store = PGVectoRs(
class _Table(_ORMBase): db_url=db_url,
__tablename__ = f"collection_{collection_name}" collection_name=collection_name,
id: Mapped[uuid.UUID] = mapped_column( dimension=dimension,
postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 recreate=new_table,
) )
text: Mapped[str] = mapped_column(sqlalchemy.String)
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
embedding: Mapped[np.ndarray] = mapped_column(Vector(dimension))
self._engine = sqlalchemy.create_engine(db_url)
self._table = _Table
self._table.__table__.create(self._engine, checkfirst=not new_table) # type: ignore
self._embedding = embedding self._embedding = embedding
# ================ Create interface ================= # ================ Create interface =================
@ -90,7 +69,6 @@ class PGVecto_rs(VectorStore):
dimension=dimension, dimension=dimension,
db_url=db_url, db_url=db_url,
collection_name=collection_name, collection_name=collection_name,
new_table=True,
) )
_self.add_texts(texts, metadatas, **kwargs) _self.add_texts(texts, metadatas, **kwargs)
return _self return _self
@ -148,19 +126,15 @@ class PGVecto_rs(VectorStore):
List of ids of the added texts. List of ids of the added texts.
""" """
from pgvecto_rs.sdk import Record
embeddings = self._embedding.embed_documents(list(texts)) embeddings = self._embedding.embed_documents(list(texts))
with Session(self._engine) as _session: records = [
results: List[str] = [] Record.from_text(text, embedding, meta)
for text, embedding, metadata in zip( for text, embedding, meta in zip(texts, embeddings, metadatas or [])
texts, embeddings, metadatas or [dict()] * len(list(texts)) ]
): self._store.insert(records)
t = insert(self._table).values( return [str(record.id) for record in records]
text=text, meta=metadata, embedding=embedding
)
id = _session.execute(t).inserted_primary_key[0] # type: ignore
results.append(str(id))
_session.commit()
return results
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore. """Run more documents through the embeddings and add to the vectorstore.
@ -185,31 +159,41 @@ class PGVecto_rs(VectorStore):
distance_func: Literal[ distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos" "sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid", ] = "sqrt_euclid",
filter: Union[None, Dict[str, Any], Any] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs most similar to query vector, with its score.""" """Return docs most similar to query vector, with its score."""
with Session(self._engine) as _session:
real_distance_func = (
self._table.embedding.squared_euclidean_distance
if distance_func == "sqrt_euclid"
else self._table.embedding.negative_dot_product_distance
if distance_func == "neg_dot_prod"
else self._table.embedding.negative_cosine_distance
if distance_func == "ned_cos"
else None
)
if real_distance_func is None:
raise ValueError("Invalid distance function")
t = ( from pgvecto_rs.sdk.filters import meta_contains
select(self._table, real_distance_func(query_vector).label("score"))
.order_by("score") distance_func_map = {
.limit(k) # type: ignore "sqrt_euclid": "<->",
"neg_dot_prod": "<#>",
"ned_cos": "<=>",
}
if filter is None:
real_filter = None
elif isinstance(filter, dict):
real_filter = meta_contains(filter)
else:
real_filter = filter
results = self._store.search(
query_vector,
distance_func_map[distance_func],
k,
filter=real_filter,
)
return [
(
Document(
page_content=res[0].text,
metadata=res[0].meta,
),
res[1],
) )
return [ for res in results
(Document(page_content=row[0].text, metadata=row[0].meta), row[1]) ]
for row in _session.execute(t)
]
def similarity_search_by_vector( def similarity_search_by_vector(
self, self,
@ -218,11 +202,12 @@ class PGVecto_rs(VectorStore):
distance_func: Literal[ distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos" "sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid", ] = "sqrt_euclid",
filter: Optional[Any] = None,
**kwargs: Any, **kwargs: Any,
) -> List[Document]: ) -> List[Document]:
return [ return [
doc doc
for doc, score in self.similarity_search_with_score_by_vector( for doc, _score in self.similarity_search_with_score_by_vector(
embedding, k, distance_func, **kwargs embedding, k, distance_func, **kwargs
) )
] ]
@ -254,7 +239,7 @@ class PGVecto_rs(VectorStore):
query_vector = self._embedding.embed_query(query) query_vector = self._embedding.embed_query(query)
return [ return [
doc doc
for doc, score in self.similarity_search_with_score_by_vector( for doc, _score in self.similarity_search_with_score_by_vector(
query_vector, k, distance_func, **kwargs query_vector, k, distance_func, **kwargs
) )
] ]

Loading…
Cancel
Save