You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/vectorstores/pgvecto_rs.py

246 lines
7.7 KiB
Python

from __future__ import annotations
import uuid
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
class PGVecto_rs(VectorStore):
"""VectorStore backed by pgvecto_rs."""
_store = None
_embedding: Embeddings
def __init__(
self,
embedding: Embeddings,
dimension: int,
db_url: str,
collection_name: str,
new_table: bool = False,
) -> None:
"""Initialize a PGVecto_rs vectorstore.
Args:
embedding: Embeddings to use.
dimension: Dimension of the embeddings.
db_url: Database URL.
collection_name: Name of the collection.
new_table: Whether to create a new table or connect to an existing one.
If true, the table will be dropped if exists, then recreated.
Defaults to False.
"""
try:
from pgvecto_rs.sdk import PGVectoRs
except ImportError as e:
raise ImportError(
"Unable to import pgvector_rs.sdk , please install with "
'`pip install "pgvecto_rs[sdk]"`.'
) from e
self._store = PGVectoRs(
db_url=db_url,
collection_name=collection_name,
dimension=dimension,
recreate=new_table,
)
self._embedding = embedding
# ================ Create interface =================
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
db_url: str = "",
collection_name: str = str(uuid.uuid4().hex),
**kwargs: Any,
) -> PGVecto_rs:
"""Return VectorStore initialized from texts and optional metadatas."""
sample_embedding = embedding.embed_query("Hello pgvecto_rs!")
dimension = len(sample_embedding)
if db_url is None:
raise ValueError("db_url must be provided")
_self: PGVecto_rs = cls(
embedding=embedding,
dimension=dimension,
db_url=db_url,
collection_name=collection_name,
)
_self.add_texts(texts, metadatas, **kwargs)
return _self
@classmethod
def from_documents(
cls,
documents: List[Document],
embedding: Embeddings,
db_url: str = "",
collection_name: str = str(uuid.uuid4().hex),
**kwargs: Any,
) -> PGVecto_rs:
"""Return VectorStore initialized from documents."""
texts = [document.page_content for document in documents]
metadatas = [document.metadata for document in documents]
return cls.from_texts(
texts, embedding, metadatas, db_url, collection_name, **kwargs
)
@classmethod
def from_collection_name(
cls,
embedding: Embeddings,
db_url: str,
collection_name: str,
) -> PGVecto_rs:
"""Create new empty vectorstore with collection_name.
Or connect to an existing vectorstore in database if exists.
Arguments should be the same as when the vectorstore was created."""
sample_embedding = embedding.embed_query("Hello pgvecto_rs!")
return cls(
embedding=embedding,
dimension=len(sample_embedding),
db_url=db_url,
collection_name=collection_name,
)
# ================ Insert interface =================
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
kwargs: vectorstore specific parameters
Returns:
List of ids of the added texts.
"""
from pgvecto_rs.sdk import Record
embeddings = self._embedding.embed_documents(list(texts))
records = [
Record.from_text(text, embedding, meta)
for text, embedding, meta in zip(texts, embeddings, metadatas or [])
]
self._store.insert(records) # type: ignore[union-attr]
return [str(record.id) for record in records]
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Run more documents through the embeddings and add to the vectorstore.
Args:
documents (List[Document]): List of documents to add to the vectorstore.
Returns:
List of ids of the added documents.
"""
return self.add_texts(
[document.page_content for document in documents],
[document.metadata for document in documents],
**kwargs,
)
# ================ Query interface =================
def similarity_search_with_score_by_vector(
self,
query_vector: List[float],
k: int = 4,
distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid",
filter: Union[None, Dict[str, Any], Any] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query vector, with its score."""
from pgvecto_rs.sdk.filters import meta_contains
distance_func_map = {
"sqrt_euclid": "<->",
"neg_dot_prod": "<#>",
"ned_cos": "<=>",
}
if filter is None:
real_filter = None
elif isinstance(filter, dict):
real_filter = meta_contains(filter)
else:
real_filter = filter
results = self._store.search( # type: ignore[union-attr]
query_vector,
distance_func_map[distance_func],
k,
filter=real_filter,
)
return [
(
Document(
page_content=res[0].text,
metadata=res[0].meta,
),
res[1],
)
for res in results
]
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid",
filter: Optional[Any] = None,
**kwargs: Any,
) -> List[Document]:
return [
doc
for doc, _score in self.similarity_search_with_score_by_vector(
embedding, k, distance_func, **kwargs
)
]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid",
**kwargs: Any,
) -> List[Tuple[Document, float]]:
query_vector = self._embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
query_vector, k, distance_func, **kwargs
)
def similarity_search(
self,
query: str,
k: int = 4,
distance_func: Literal[
"sqrt_euclid", "neg_dot_prod", "ned_cos"
] = "sqrt_euclid",
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to query."""
query_vector = self._embedding.embed_query(query)
return [
doc
for doc, _score in self.similarity_search_with_score_by_vector(
query_vector, k, distance_func, **kwargs
)
]