community[minor]: Update pgvecto_rs to use its high level sdk (#15574)

- **Description:** Update pgvecto_rs to use its high-level SDK.
  - **Issue:** fix #15173
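
For context, here is a minimal sketch of the high-level `pgvecto_rs` SDK surface this change migrates to, using only the calls visible in the diff below (`PGVectoRs`, `Record.from_text`, `insert`, `search`, `meta_contains`); the connection URL and collection name are placeholders:

```python
from pgvecto_rs.sdk import PGVectoRs, Record
from pgvecto_rs.sdk.filters import meta_contains

# Placeholder URL: point it at a Postgres instance with pgvecto.rs enabled.
URL = "postgresql+psycopg://postgres:mysecretpassword@localhost:5432/postgres"

# The SDK client owns the collection table; recreate=True drops and recreates
# it, which is what the vectorstore's `new_table` flag is mapped onto.
store = PGVectoRs(db_url=URL, collection_name="demo", dimension=3, recreate=True)

# A Record pairs a text with a precomputed embedding and a metadata dict.
store.insert([Record.from_text("hello pgvecto.rs", [0.1, 0.2, 0.3], {"source": "demo"})])

# search takes a query vector, a distance operator ("<->", "<#>", "<=>"),
# a top-k, and an optional metadata filter; results are (Record, score) pairs.
for record, score in store.search(
    [0.1, 0.2, 0.3], "<->", 4, filter=meta_contains({"source": "demo"})
):
    print(record.text, record.meta, score)
```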
盐粒 Yanli 6 months ago committed by GitHub
parent ce21392a21
commit ddf4e7c633

@@ -6,7 +6,7 @@
"source": [
"# PGVecto.rs\n",
"\n",
"This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs)). You need to install SQLAlchemy >= 2 manually."
"This notebook shows how to use functionality related to the Postgres vector database ([pgvecto.rs](https://github.com/tensorchord/pgvecto.rs))."
]
},
{
@@ -15,10 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"## Loading Environment Variables\n",
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()"
"%pip install \"pgvecto_rs[sdk]\""
]
},
{
@@ -32,8 +29,8 @@
"from langchain.docstore.document import Document\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain_community.document_loaders import TextLoader\n",
"from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs\n",
"from langchain_openai import OpenAIEmbeddings"
"from langchain_community.embeddings.fake import FakeEmbeddings\n",
"from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs"
]
},
{
@@ -42,12 +39,12 @@
"metadata": {},
"outputs": [],
"source": [
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
"loader = TextLoader(\"../../modules/state_of_the_union.txt\")\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
"\n",
"embeddings = OpenAIEmbeddings()"
"embeddings = FakeEmbeddings(size=3)"
]
},
{
@@ -176,7 +173,42 @@
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs: List[Document] = db1.similarity_search(query, k=4)"
"docs: List[Document] = db1.similarity_search(query, k=4)\n",
"for doc in docs:\n",
" print(doc.page_content)\n",
" print(\"======================\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Similarity Search with Filter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pgvecto_rs.sdk.filters import meta_contains\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs: List[Document] = db1.similarity_search(\n",
" query, k=4, filter=meta_contains({\"source\": \"../../modules/state_of_the_union.txt\"})\n",
")\n",
"\n",
"for doc in docs:\n",
" print(doc.page_content)\n",
" print(\"======================\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or:"
]
},
{
@@ -185,6 +217,11 @@
"metadata": {},
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs: List[Document] = db1.similarity_search(\n",
" query, k=4, filter={\"source\": \"../../modules/state_of_the_union.txt\"}\n",
")\n",
"\n",
"for doc in docs:\n",
" print(doc.page_content)\n",
" print(\"======================\")"
@@ -207,7 +244,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.11.6"
}
},
"nbformat": 4,

@@ -1,32 +1,17 @@
from __future__ import annotations
import uuid
from typing import Any, Iterable, List, Literal, Optional, Tuple, Type
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
import numpy as np
import sqlalchemy
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from sqlalchemy import insert, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.orm.session import Session
class _ORMBase(DeclarativeBase):
__tablename__: str
id: Mapped[uuid.UUID]
text: Mapped[str]
meta: Mapped[dict]
embedding: Mapped[np.ndarray]
class PGVecto_rs(VectorStore):
"""VectorStore backed by pgvecto_rs."""
_engine: sqlalchemy.engine.Engine
_table: Type[_ORMBase]
_store = None
_embedding: Embeddings
def __init__(
@@ -45,28 +30,22 @@ class PGVecto_rs(VectorStore):
db_url: Database URL.
collection_name: Name of the collection.
new_table: Whether to create a new table or connect to an existing one.
Defaults to False.
If true, the table will be dropped if exists, then recreated.
Defaults to False.
"""
try:
from pgvecto_rs.sqlalchemy import Vector
from pgvecto_rs.sdk import PGVectoRs
except ImportError as e:
raise ImportError(
"Unable to import pgvector_rs, please install with "
"`pip install pgvector_rs`."
"Unable to import pgvector_rs.sdk , please install with "
'`pip install "pgvector_rs[sdk]"`.'
) from e
class _Table(_ORMBase):
__tablename__ = f"collection_{collection_name}"
id: Mapped[uuid.UUID] = mapped_column(
postgresql.UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
)
text: Mapped[str] = mapped_column(sqlalchemy.String)
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
embedding: Mapped[np.ndarray] = mapped_column(Vector(dimension))
self._engine = sqlalchemy.create_engine(db_url)
self._table = _Table
self._table.__table__.create(self._engine, checkfirst=not new_table) # type: ignore
self._store = PGVectoRs(
db_url=db_url,
collection_name=collection_name,
dimension=dimension,
recreate=new_table,
)
self._embedding = embedding
# ================ Create interface =================
@@ -90,7 +69,6 @@ class PGVecto_rs(VectorStore):
dimension=dimension,
db_url=db_url,
collection_name=collection_name,
new_table=True,
)
_self.add_texts(texts, metadatas, **kwargs)
return _self
@@ -148,19 +126,15 @@ class PGVecto_rs(VectorStore):
List of ids of the added texts.
"""
from pgvecto_rs.sdk import Record
embeddings = self._embedding.embed_documents(list(texts))
with Session(self._engine) as _session:
results: List[str] = []
for text, embedding, metadata in zip(
texts, embeddings, metadatas or [dict()] * len(list(texts))
):
t = insert(self._table).values(
text=text, meta=metadata, embedding=embedding
)
id = _session.execute(t).inserted_primary_key[0] # type: ignore
results.append(str(id))
_session.commit()
return results
records = [
Record.from_text(text, embedding, meta)
for text, embedding, meta in zip(texts, embeddings, metadatas or [])
]
self._store.insert(records)
return [str(record.id) for record in records]
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
@@ -185,31 +159,41 @@ class PGVecto_rs(VectorStore):
distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid",
filter: Union[None, Dict[str, Any], Any] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query vector, with its score."""
with Session(self._engine) as _session:
real_distance_func = (
self._table.embedding.squared_euclidean_distance
if distance_func == "sqrt_euclid"
else self._table.embedding.negative_dot_product_distance
if distance_func == "neg_dot_prod"
else self._table.embedding.negative_cosine_distance
if distance_func == "ned_cos"
else None
)
if real_distance_func is None:
raise ValueError("Invalid distance function")
t = (
select(self._table, real_distance_func(query_vector).label("score"))
.order_by("score")
.limit(k) # type: ignore
from pgvecto_rs.sdk.filters import meta_contains
distance_func_map = {
"sqrt_euclid": "<->",
"neg_dot_prod": "<#>",
"ned_cos": "<=>",
}
if filter is None:
real_filter = None
elif isinstance(filter, dict):
real_filter = meta_contains(filter)
else:
real_filter = filter
results = self._store.search(
query_vector,
distance_func_map[distance_func],
k,
filter=real_filter,
)
return [
(
Document(
page_content=res[0].text,
metadata=res[0].meta,
),
res[1],
)
return [
(Document(page_content=row[0].text, metadata=row[0].meta), row[1])
for row in _session.execute(t)
]
for res in results
]
def similarity_search_by_vector(
self,
@@ -218,11 +202,12 @@ class PGVecto_rs(VectorStore):
distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid",
filter: Optional[Any] = None,
**kwargs: Any,
) -> List[Document]:
return [
doc
for doc, score in self.similarity_search_with_score_by_vector(
for doc, _score in self.similarity_search_with_score_by_vector(
embedding, k, distance_func, **kwargs
)
]
@@ -254,7 +239,7 @@ class PGVecto_rs(VectorStore):
query_vector = self._embedding.embed_query(query)
return [
doc
for doc, score in self.similarity_search_with_score_by_vector(
for doc, _score in self.similarity_search_with_score_by_vector(
query_vector, k, distance_func, **kwargs
)
]
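
Finally, a short, hedged usage sketch against the refactored class, assuming a `db1` built as in the notebook sketch above; the added text and metadata are illustrative only:

```python
# add_texts now builds SDK Records and returns their ids as strings.
ids = db1.add_texts(
    ["pgvecto.rs supports several distance operators"],
    metadatas=[{"source": "inline-note"}],
)
print(ids)

# distance_func keeps the legacy names; internally they map to the SDK
# operators "<->" (sqrt_euclid), "<#>" (neg_dot_prod), "<=>" (ned_cos),
# and a plain dict filter is wrapped in meta_contains().
for doc in db1.similarity_search(
    "distance operators",
    k=2,
    distance_func="neg_dot_prod",
    filter={"source": "inline-note"},
):
    print(doc.page_content)
```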
