diff --git a/langchain/vectorstores/pgvector.py b/langchain/vectorstores/pgvector.py index 64595230..161e8e2c 100644 --- a/langchain/vectorstores/pgvector.py +++ b/langchain/vectorstores/pgvector.py @@ -192,6 +192,72 @@ class PGVector(VectorStore): def get_collection(self, session: Session) -> Optional["CollectionStore"]: return CollectionStore.get_by_name(session, self.collection_name) + @classmethod + def __from( + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + connection_string = cls.get_connection_string(kwargs) + + store = cls( + connection_string=connection_string, + collection_name=collection_name, + embedding_function=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + ) + + store.add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + return store + + def add_embeddings( + self, + texts: List[str], + embeddings: List[List[float]], + metadatas: List[dict], + ids: List[str], + **kwargs: Any, + ) -> None: + """Add embeddings to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + embeddings: List of list of embedding vectors. + metadatas: List of metadatas associated with the texts. + kwargs: vectorstore specific parameters + """ + with Session(self._conn) as session: + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): + embedding_store = EmbeddingStore( + embedding=embedding, + document=text, + cmetadata=metadata, + custom_id=id, + ) + collection.embeddings.append(embedding_store) + session.add(embedding_store) + session.commit() + def add_texts( self, texts: Iterable[str], @@ -380,19 +446,64 @@ class PGVector(VectorStore): "Either pass it as a parameter or set the PGVECTOR_CONNECTION_STRING environment variable. """ + embeddings = embedding.embed_documents(list(texts)) - connection_string = cls.get_connection_string(kwargs) - - store = cls( - connection_string=connection_string, + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, collection_name=collection_name, - embedding_function=embedding, distance_strategy=distance_strategy, pre_delete_collection=pre_delete_collection, + **kwargs, ) - store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs) - return store + @classmethod + def from_embeddings( + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DistanceStrategy.COSINE, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + """Construct PGVector wrapper from raw documents and pre- + generated embeddings. + + Return VectorStore initialized from documents and embeddings. + Postgres connection string is required + "Either pass it as a parameter + or set the PGVECTOR_CONNECTION_STRING environment variable. + + Example: + .. code-block:: python + + from langchain import PGVector + from langchain.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + text_embeddings = embeddings.embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + faiss = PGVector.from_embeddings(text_embedding_pairs, embeddings) + """ + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + return cls.__from( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) @classmethod def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: diff --git a/tests/integration_tests/vectorstores/test_pgvector.py b/tests/integration_tests/vectorstores/test_pgvector.py index 8ad7f1bc..b8d3314d 100644 --- a/tests/integration_tests/vectorstores/test_pgvector.py +++ b/tests/integration_tests/vectorstores/test_pgvector.py @@ -49,6 +49,22 @@ def test_pgvector() -> None: assert output == [Document(page_content="foo")] +def test_pgvector_embeddings() -> None: + """Test end to end construction with embeddings and search.""" + texts = ["foo", "bar", "baz"] + text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = PGVector.from_embeddings( + text_embeddings=text_embedding_pairs, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection_string=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + def test_pgvector_with_metadatas() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"]