From 3bae595182ccff8a570515695285c016c8c0a5e1 Mon Sep 17 00:00:00 2001 From: Sheng Han Lim Date: Thu, 1 Jun 2023 08:31:52 +0800 Subject: [PATCH] Add texts with embeddings to PGVector wrapper (#5500) Similar to #1813 for faiss, this PR is to extend functionality to pass text and its vector pair to initialize and add embeddings to the PGVector wrapper. Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: - @dev2049 --- langchain/vectorstores/pgvector.py | 125 +++++++++++++++++- .../vectorstores/test_pgvector.py | 16 +++ 2 files changed, 134 insertions(+), 7 deletions(-) diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index 64595230..161e8e2c 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -192,6 +192,72 @@ class PGVector(VectorStore): def get_collection(self, session: Session) -> Optional["CollectionStore"]: return CollectionStore.get_by_name(session, self.collection_name) + @classmethod + def __from( + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + connection_string = cls.get_connection_string(kwargs) + + store = cls( + connection_string=connection_string, + collection_name=collection_name, + embedding_function=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + ) + + store.add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + return store + + def add_embeddings( + self, + texts: List[str], + embeddings: 
List[List[float]], + metadatas: List[dict], + ids: List[str], + **kwargs: Any, + ) -> None: + """Add embeddings to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + embeddings: List of list of embedding vectors. + metadatas: List of metadatas associated with the texts. + kwargs: vectorstore specific parameters + """ + with Session(self._conn) as session: + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): + embedding_store = EmbeddingStore( + embedding=embedding, + document=text, + cmetadata=metadata, + custom_id=id, + ) + collection.embeddings.append(embedding_store) + session.add(embedding_store) + session.commit() + def add_texts( self, texts: Iterable[str], @@ -380,19 +446,64 @@ class PGVector(VectorStore): "Either pass it as a parameter or set the PGVECTOR_CONNECTION_STRING environment variable. """ + embeddings = embedding.embed_documents(list(texts)) - connection_string = cls.get_connection_string(kwargs) - - store = cls( - connection_string=connection_string, + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, collection_name=collection_name, - embedding_function=embedding, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, + **kwargs, ) - store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs) - return store + @classmethod + def from_embeddings( + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + """Construct PGVector wrapper from raw documents and pre- + generated embeddings. 
+ + Return VectorStore initialized from documents and embeddings. + Postgres connection string is required. + Either pass it as a parameter + or set the PGVECTOR_CONNECTION_STRING environment variable. + + Example: + .. code-block:: python + + from langchain import PGVector + from langchain.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + text_embeddings = embeddings.embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + pgvector = PGVector.from_embeddings(text_embedding_pairs, embeddings) + """ + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) @classmethod def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: diff --git a/tests/integration_tests/vectorstores/test_pgvector.py b/tests/integration_tests/vectorstores/test_pgvector.py index 8ad7f1bc..b8d3314d 100644 --- a/tests/integration_tests/vectorstores/test_pgvector.py +++ b/tests/integration_tests/vectorstores/test_pgvector.py @@ -49,6 +49,22 @@ def test_pgvector() -> None: assert output == [Document(page_content="foo")] +def test_pgvector_embeddings() -> None: + """Test end to end construction with embeddings and search.""" + texts = ["foo", "bar", "baz"] + text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = PGVector.from_embeddings( + text_embeddings=text_embedding_pairs, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + def test_pgvector_with_metadatas() -> None: """Test end to end 
construction and search.""" texts = ["foo", "bar", "baz"]