Add texts with embeddings to PGVector wrapper (#5500)

Similar to #1813 for faiss, this PR is to extend functionality to pass
text and its vector pair to initialize and add embeddings to the
PGVector wrapper.

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
  - @dev2049
This commit is contained in:
Sheng Han Lim 2023-06-01 08:31:52 +08:00 committed by GitHub
parent 8d07ba0d51
commit 3bae595182
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 134 additions and 7 deletions

View File

@ -192,6 +192,72 @@ class PGVector(VectorStore):
def get_collection(self, session: Session) -> Optional["CollectionStore"]: def get_collection(self, session: Session) -> Optional["CollectionStore"]:
return CollectionStore.get_by_name(session, self.collection_name) return CollectionStore.get_by_name(session, self.collection_name)
@classmethod
def __from(
cls,
texts: List[str],
embeddings: List[List[float]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> PGVector:
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]
if not metadatas:
metadatas = [{} for _ in texts]
connection_string = cls.get_connection_string(kwargs)
store = cls(
connection_string=connection_string,
collection_name=collection_name,
embedding_function=embedding,
distance_strategy=distance_strategy,
pre_delete_collection=pre_delete_collection,
)
store.add_embeddings(
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
)
return store
def add_embeddings(
self,
texts: List[str],
embeddings: List[List[float]],
metadatas: List[dict],
ids: List[str],
**kwargs: Any,
) -> None:
"""Add embeddings to the vectorstore.
Args:
texts: Iterable of strings to add to the vectorstore.
embeddings: List of list of embedding vectors.
metadatas: List of metadatas associated with the texts.
kwargs: vectorstore specific parameters
"""
with Session(self._conn) as session:
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
embedding_store = EmbeddingStore(
embedding=embedding,
document=text,
cmetadata=metadata,
custom_id=id,
)
collection.embeddings.append(embedding_store)
session.add(embedding_store)
session.commit()
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
@ -380,19 +446,64 @@ class PGVector(VectorStore):
"Either pass it as a parameter "Either pass it as a parameter
or set the PGVECTOR_CONNECTION_STRING environment variable. or set the PGVECTOR_CONNECTION_STRING environment variable.
""" """
embeddings = embedding.embed_documents(list(texts))
connection_string = cls.get_connection_string(kwargs) return cls.__from(
texts,
store = cls( embeddings,
connection_string=connection_string, embedding,
metadatas=metadatas,
ids=ids,
collection_name=collection_name, collection_name=collection_name,
embedding_function=embedding,
distance_strategy=distance_strategy, distance_strategy=distance_strategy,
pre_delete_collection=pre_delete_collection, pre_delete_collection=pre_delete_collection,
**kwargs,
) )
store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs) @classmethod
return store def from_embeddings(
cls,
text_embeddings: List[Tuple[str, List[float]]],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
ids: Optional[List[str]] = None,
pre_delete_collection: bool = False,
**kwargs: Any,
) -> PGVector:
"""Construct PGVector wrapper from raw documents and pre-
generated embeddings.
Return VectorStore initialized from documents and embeddings.
Postgres connection string is required
"Either pass it as a parameter
or set the PGVECTOR_CONNECTION_STRING environment variable.
Example:
.. code-block:: python
from langchain import PGVector
from langchain.embeddings import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
text_embeddings = embeddings.embed_documents(texts)
text_embedding_pairs = list(zip(texts, text_embeddings))
faiss = PGVector.from_embeddings(text_embedding_pairs, embeddings)
"""
texts = [t[0] for t in text_embeddings]
embeddings = [t[1] for t in text_embeddings]
return cls.__from(
texts,
embeddings,
embedding,
metadatas=metadatas,
ids=ids,
collection_name=collection_name,
distance_strategy=distance_strategy,
pre_delete_collection=pre_delete_collection,
**kwargs,
)
@classmethod @classmethod
def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:

View File

@ -49,6 +49,22 @@ def test_pgvector() -> None:
assert output == [Document(page_content="foo")] assert output == [Document(page_content="foo")]
def test_pgvector_embeddings() -> None:
"""Test end to end construction with embeddings and search."""
texts = ["foo", "bar", "baz"]
text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts)
text_embedding_pairs = list(zip(texts, text_embeddings))
docsearch = PGVector.from_embeddings(
text_embeddings=text_embedding_pairs,
collection_name="test_collection",
embedding=FakeEmbeddingsWithAdaDimension(),
connection_string=CONNECTION_STRING,
pre_delete_collection=True,
)
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]
def test_pgvector_with_metadatas() -> None: def test_pgvector_with_metadatas() -> None:
"""Test end to end construction and search.""" """Test end to end construction and search."""
texts = ["foo", "bar", "baz"] texts = ["foo", "bar", "baz"]