From 6fc24743b7899c1ba4c397d1669d6821e64d7c0c Mon Sep 17 00:00:00 2001 From: Raouf Chebri Date: Wed, 5 Jul 2023 17:10:10 +0200 Subject: [PATCH] Add pg_hnsw vectorstore integration (#6893) Hi @rlancemartin, @eyurtsev! - Description: Adding HNSW extension support for Postgres. Similar to pgvector vectorstore, with 3 differences 1. it uses HNSW extension for exact and ANN searches, 2. Vectors are of type array of real 3. Only supports L2 - Dependencies: [HNSW](https://github.com/knizhnik/hnsw) extension for Postgres - Example: ```python db = HNSWVectoreStore.from_documents( embedding=embeddings, documents=docs, collection_name=collection_name, connection_string=connection_string ) query = "What did the president say about Ketanji Brown Jackson" docs_with_score: List[Tuple[Document, float]] = db.similarity_search_with_score(query) ``` The example notebook is in the PR too. --- .../integrations/pgembedding.ipynb | 338 ++++++++++++ langchain/vectorstores/__init__.py | 2 + langchain/vectorstores/pgembedding.py | 510 ++++++++++++++++++ 3 files changed, 850 insertions(+) create mode 100644 docs/extras/modules/data_connection/vectorstores/integrations/pgembedding.ipynb create mode 100644 langchain/vectorstores/pgembedding.py diff --git a/docs/extras/modules/data_connection/vectorstores/integrations/pgembedding.ipynb b/docs/extras/modules/data_connection/vectorstores/integrations/pgembedding.ipynb new file mode 100644 index 0000000000..051e7ef569 --- /dev/null +++ b/docs/extras/modules/data_connection/vectorstores/integrations/pgembedding.ipynb @@ -0,0 +1,338 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "1292f057", + "metadata": {}, + "source": [ + "# pg_hnsw\n", + "\n", + "> [pg_embedding](https://github.com/knizhnik/hnsw) is an open-source vector similarity search for `Postgres` that uses Hierarchical Navigable Small Worlds for approximate nearest neighbor search.\n", + "\n", + "It supports:\n", + "- exact and approximate nearest neighbor search using HNSW\n", + "- L2 distance\n", + "\n", + "This notebook shows how to use the Postgres vector database (`PGEmbedding`).\n", + "\n", + "> The PGEmbedding integration creates the pg_embedding extension for you, but you run the following Postgres query to add it:\n", + "```sql\n", + "CREATE EXTENSION embedding;\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6214221", + "metadata": {}, + "outputs": [], + "source": [ + "# Pip install necessary package\n", + "!pip install openai\n", + "!pip install psycopg2-binary\n", + "!pip install tiktoken" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "b2e49694", + "metadata": {}, + "source": [ + "Add the OpenAI API Key to the environment variables to use `OpenAIEmbeddings`." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1dcc8d99", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "OpenAI API Key:········\n" + ] + } + ], + "source": [ + "import os\n", + "import getpass\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9719ea68", + "metadata": {}, + "outputs": [], + "source": [ + "## Loading Environment Variables\n", + "from typing import List, Tuple" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfd1f38d", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import PGEmbedding\n", + "from langchain.document_loaders import TextLoader\n", + "from langchain.docstore.document import Document" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8fab8cc2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Database Url:········\n" + ] + } + ], + "source": [ + "os.environ[\"DATABASE_URL\"] = getpass.getpass(\"Database Url:\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "bef17115", + "metadata": {}, + "outputs": [], + "source": [ + "loader = TextLoader(\"state_of_the_union.txt\")\n", + "documents = loader.load()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "docs = text_splitter.split_documents(documents)\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "connection_string = os.environ.get(\"DATABASE_URL\")\n", + "collection_name = \"state_of_the_union\"" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "743abfaa", + "metadata": {}, + "outputs": [], + "source": [ + "db = PGEmbedding.from_documents(\n", + " embedding=embeddings,\n", + " documents=docs,\n", + " collection_name=collection_name,\n", + " connection_string=connection_string,\n", + ")\n", + "\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs_with_score: List[Tuple[Document, float]] = db.similarity_search_with_score(query)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41ce4c4e", + "metadata": {}, + "outputs": [], + "source": [ + "for doc, score in docs_with_score:\n", + " print(\"-\" * 80)\n", + " print(\"Score: \", score)\n", + " print(doc.page_content)\n", + " print(\"-\" * 80)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "7ef7b052", + "metadata": {}, + "source": [ + "## Working with vectorstore in Postgres" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "939151f7", + "metadata": {}, + "source": [ + "### Uploading a vectorstore in PG " + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "595ac511", + "metadata": {}, + "outputs": [], + "source": [ + "db = PGEmbedding.from_documents(\n", + " embedding=embeddings,\n", + " documents=docs,\n", + " collection_name=collection_name,\n", + " connection_string=connection_string,\n", + " pre_delete_collection=False,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f9510e6b", + "metadata": {}, + "source": [ + "### Create HNSW Index\n", + "By default, the extension performs a sequential scan search, with 100% recall. You might consider creating an HNSW index for approximate nearest neighbor (ANN) search to speed up `similarity_search_with_score` execution time. To create the HNSW index on your vector column, use a `create_hnsw_index` function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d1981fa", + "metadata": {}, + "outputs": [], + "source": [ + "PGEmbedding.create_hnsw_index(\n", + " max_elements=10000, dims=1536, m=8, ef_construction=16, ef_search=16\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "7adacf29", + "metadata": {}, + "source": [ + "The function above is equivalent to running the below SQL query:\n", + "```sql\n", + "CREATE INDEX ON vectors USING hnsw(vec) WITH (maxelements=10000, dims=1536, m=3, efconstruction=16, efsearch=16);\n", + "```\n", + "The HNSW index options used in the statement above include:\n", + "\n", + "- maxelements: Defines the maximum number of elements indexed. This is a required parameter. The example shown above has a value of 3. A real-world example would have a much large value, such as 1000000. An \"element\" refers to a data point (a vector) in the dataset, which is represented as a node in the HNSW graph. Typically, you would set this option to a value able to accommodate the number of rows in your in your dataset.\n", + "- dims: Defines the number of dimensions in your vector data. This is a required parameter. A small value is used in the example above. If you are storing data generated using OpenAI's text-embedding-ada-002 model, which supports 1536 dimensions, you would define a value of 1536, for example.\n", + "- m: Defines the maximum number of bi-directional links (also referred to as \"edges\") created for each node during graph construction.\n", + "The following additional index options are supported:\n", + "\n", + "- efConstruction: Defines the number of nearest neighbors considered during index construction. The default value is 32.\n", + "- efsearch: Defines the number of nearest neighbors considered during index search. The default value is 32.\n", + "For information about how you can configure these options to influence the HNSW algorithm, refer to [Tuning the HNSW algorithm](https://neon-next-git-dprice-hnsw-extension-neondatabase.vercel.app/docs/extensions/hnsw#tuning-the-hnsw-algorithm)." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "528893fb", + "metadata": {}, + "source": [ + "### Retrieving a vectorstore in PG" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "b6162b1c", + "metadata": {}, + "outputs": [], + "source": [ + "store = PGEmbedding(\n", + " connection_string=connection_string,\n", + " embedding_function=embeddings,\n", + " collection_name=collection_name,\n", + ")\n", + "\n", + "retriever = store.as_retriever()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1a5fedb1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "VectorStoreRetriever(vectorstore=, search_type='similarity', search_kwargs={})" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "retriever" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "0cefc938", + "metadata": {}, + "outputs": [], + "source": [ + "db1 = PGEmbedding.from_existing_index(\n", + " embedding=embeddings,\n", + " collection_name=collection_name,\n", + " pre_delete_collection=False,\n", + " connection_string=connection_string,\n", + ")\n", + "\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs_with_score: List[Tuple[Document, float]] = db1.similarity_search_with_score(query)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85cde495", + "metadata": {}, + "outputs": [], + "source": [ + "for doc, score in docs_with_score:\n", + " print(\"-\" * 80)\n", + " print(\"Score: \", score)\n", + " print(doc.page_content)\n", + " print(\"-\" * 80)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/vectorstores/__init__.py b/langchain/vectorstores/__init__.py index 345aeec481..cb3db2fa56 100644 --- a/langchain/vectorstores/__init__.py +++ b/langchain/vectorstores/__init__.py @@ -24,6 +24,7 @@ from langchain.vectorstores.milvus import Milvus from langchain.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch from langchain.vectorstores.myscale import MyScale, MyScaleSettings from langchain.vectorstores.opensearch_vector_search import OpenSearchVectorSearch +from langchain.vectorstores.pgembedding import PGEmbedding from langchain.vectorstores.pinecone import Pinecone from langchain.vectorstores.qdrant import Qdrant from langchain.vectorstores.redis import Redis @@ -56,6 +57,7 @@ __all__ = [ "DocArrayInMemorySearch", "ElasticVectorSearch", "FAISS", + "PGEmbedding", "Hologres", "LanceDB", "MatchingEngine", diff --git a/langchain/vectorstores/pgembedding.py b/langchain/vectorstores/pgembedding.py new file mode 100644 index 0000000000..9dc76d9a9a --- /dev/null +++ b/langchain/vectorstores/pgembedding.py @@ -0,0 +1,510 @@ +"""VectorStore wrapper around a Postgres database.""" +from __future__ import annotations + +import logging +import uuid +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type + +import sqlalchemy +from sqlalchemy import func +from sqlalchemy.dialects.postgresql import JSON, UUID +from sqlalchemy.orm import Session, declarative_base, relationship + +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env +from langchain.vectorstores.base import VectorStore + +Base = declarative_base() # type: Any + + +ADA_TOKEN_COUNT = 1536 +_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain" + + +class BaseModel(Base): + __abstract__ = True + uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + + +class CollectionStore(BaseModel): + __tablename__ = "langchain_pg_collection" + + name = sqlalchemy.Column(sqlalchemy.String) + cmetadata = sqlalchemy.Column(JSON) + + embeddings = relationship( + "EmbeddingStore", + back_populates="collection", + passive_deletes=True, + ) + + @classmethod + def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]: + return session.query(cls).filter(cls.name == name).first() # type: ignore + + @classmethod + def get_or_create( + cls, + session: Session, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True if the collection was created. + """ + created = False + collection = cls.get_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + session.commit() + created = True + return collection, created + + +class EmbeddingStore(BaseModel): + __tablename__ = "langchain_pg_embedding" + + collection_id = sqlalchemy.Column( + UUID(as_uuid=True), + sqlalchemy.ForeignKey( + f"{CollectionStore.__tablename__}.uuid", + ondelete="CASCADE", + ), + ) + collection = relationship(CollectionStore, back_populates="embeddings") + + embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL)) # type: ignore + document = sqlalchemy.Column(sqlalchemy.String, nullable=True) + cmetadata = sqlalchemy.Column(JSON, nullable=True) + + # custom_id : any user defined id + custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True) + + +class QueryResult: + EmbeddingStore: EmbeddingStore + distance: float + + +class PGEmbedding(VectorStore): + """ + VectorStore implementation using Postgres and the pg_embedding extension. + pg_embedding uses sequential scan by default. but you can create a HNSW index + using the create_hnsw_index method. + - `connection_string` is a postgres connection string. + - `embedding_function` any embedding function implementing + `langchain.embeddings.base.Embeddings` interface. + - `collection_name` is the name of the collection to use. (default: langchain) + - NOTE: This is not the name of the table, but the name of the collection. + The tables will be created when initializing the store (if not exists) + So, make sure the user has the right permissions to create tables. + - `distance_strategy` is the distance strategy to use. (default: EUCLIDEAN) + - `EUCLIDEAN` is the euclidean distance. + - `pre_delete_collection` if True, will delete the collection if it exists. + (default: False) + - Useful for testing. + """ + + def __init__( + self, + connection_string: str, + embedding_function: Embeddings, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + collection_metadata: Optional[dict] = None, + pre_delete_collection: bool = False, + logger: Optional[logging.Logger] = None, + ) -> None: + self.connection_string = connection_string + self.embedding_function = embedding_function + self.collection_name = collection_name + self.collection_metadata = collection_metadata + self.pre_delete_collection = pre_delete_collection + self.logger = logger or logging.getLogger(__name__) + self.__post_init__() + + def __post_init__( + self, + ) -> None: + self._conn = self.connect() + self.create_hnsw_extension() + self.create_tables_if_not_exists() + self.create_collection() + + def connect(self) -> sqlalchemy.engine.Connection: + engine = sqlalchemy.create_engine(self.connection_string) + conn = engine.connect() + return conn + + def create_hnsw_extension(self) -> None: + try: + with Session(self._conn) as session: + statement = sqlalchemy.text( + "CREATE EXTENSION IF NOT EXISTS pg_embedding" + ) + session.execute(statement) + session.commit() + except Exception as e: + self.logger.exception(e) + + def create_tables_if_not_exists(self) -> None: + with self._conn.begin(): + Base.metadata.create_all(self._conn) + + def drop_tables(self) -> None: + with self._conn.begin(): + Base.metadata.drop_all(self._conn) + + def create_collection(self) -> None: + if self.pre_delete_collection: + self.delete_collection() + with Session(self._conn) as session: + CollectionStore.get_or_create( + session, self.collection_name, cmetadata=self.collection_metadata + ) + + def create_hnsw_index( + self, + max_elements: int = 10000, + dims: int = ADA_TOKEN_COUNT, + m: int = 8, + ef_construction: int = 16, + ef_search: int = 16, + ) -> None: + create_index_query = sqlalchemy.text( + "CREATE INDEX IF NOT EXISTS langchain_pg_embedding_idx " + "ON langchain_pg_embedding USING hnsw (embedding) " + "WITH (" + "maxelements = {}, " + "dims = {}, " + "m = {}, " + "efconstruction = {}, " + "efsearch = {}" + ");".format(max_elements, dims, m, ef_construction, ef_search) + ) + + # Execute the queries + try: + with Session(self._conn) as session: + # Create the HNSW index + session.execute(create_index_query) + session.commit() + print("HNSW extension and index created successfully.") + except Exception as e: + print(f"Failed to create HNSW extension or index: {e}") + + def delete_collection(self) -> None: + self.logger.debug("Trying to delete collection") + with Session(self._conn) as session: + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + session.delete(collection) + session.commit() + + def get_collection(self, session: Session) -> Optional["CollectionStore"]: + return CollectionStore.get_by_name(session, self.collection_name) + + @classmethod + def _initialize_from_embeddings( + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGEmbedding: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + connection_string = cls.get_connection_string(kwargs) + + store = cls( + connection_string=connection_string, + collection_name=collection_name, + embedding_function=embedding, + pre_delete_collection=pre_delete_collection, + ) + + store.add_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + return store + + def add_embeddings( + self, + texts: List[str], + embeddings: List[List[float]], + metadatas: List[dict], + ids: List[str], + **kwargs: Any, + ) -> None: + with Session(self._conn) as session: + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): + embedding_store = EmbeddingStore( + embedding=embedding, + document=text, + cmetadata=metadata, + custom_id=id, + ) + collection.embeddings.append(embedding_store) + session.add(embedding_store) + session.commit() + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + embeddings = self.embedding_function.embed_documents(list(texts)) + + if not metadatas: + metadatas = [{} for _ in texts] + + with Session(self._conn) as session: + collection = self.get_collection(session) + if not collection: + raise ValueError("Collection not found") + for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids): + embedding_store = EmbeddingStore( + embedding=embedding, + document=text, + cmetadata=metadata, + custom_id=id, + ) + collection.embeddings.append(embedding_store) + session.add(embedding_store) + session.commit() + + return ids + + def similarity_search( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + embedding = self.embedding_function.embed_query(text=query) + return self.similarity_search_by_vector( + embedding=embedding, + k=k, + filter=filter, + ) + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + embedding = self.embedding_function.embed_query(query) + docs = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return docs + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + with Session(self._conn) as session: + collection = self.get_collection(session) + set_enable_seqscan_stmt = sqlalchemy.text("SET enable_seqscan = off") + session.execute(set_enable_seqscan_stmt) + if not collection: + raise ValueError("Collection not found") + + filter_by = EmbeddingStore.collection_id == collection.uuid + + if filter is not None: + filter_clauses = [] + for key, value in filter.items(): + IN = "in" + if isinstance(value, dict) and IN in map(str.lower, value): + value_case_insensitive = { + k.lower(): v for k, v in value.items() + } + filter_by_metadata = EmbeddingStore.cmetadata[key].astext.in_( + value_case_insensitive[IN] + ) + filter_clauses.append(filter_by_metadata) + else: + filter_by_metadata = EmbeddingStore.cmetadata[ + key + ].astext == str(value) + filter_clauses.append(filter_by_metadata) + + filter_by = sqlalchemy.and_(filter_by, *filter_clauses) + + results: List[QueryResult] = ( + session.query( + EmbeddingStore, + func.abs(EmbeddingStore.embedding.op("<->")(embedding)).label( + "distance" + ), + ) # Specify the columns you need here, e.g., EmbeddingStore.embedding + .filter(filter_by) + .order_by( + func.abs(EmbeddingStore.embedding.op("<->")(embedding)).asc() + ) # Using PostgreSQL specific operator with the correct column name + .limit(k) + .all() + ) + + docs = [ + ( + Document( + page_content=result.EmbeddingStore.document, + metadata=result.EmbeddingStore.cmetadata, + ), + result.distance if self.embedding_function is not None else None, + ) + for result in results + ] + return docs + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + docs_and_scores = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return [doc for doc, _ in docs_and_scores] + + @classmethod + def from_texts( + cls: Type[PGEmbedding], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGEmbedding: + embeddings = embedding.embed_documents(list(texts)) + + return cls._initialize_from_embeddings( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + + @classmethod + def from_embeddings( + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGEmbedding: + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + return cls._initialize_from_embeddings( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + + @classmethod + def from_existing_index( + cls: Type[PGEmbedding], + embedding: Embeddings, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGEmbedding: + connection_string = cls.get_connection_string(kwargs) + + store = cls( + connection_string=connection_string, + collection_name=collection_name, + embedding_function=embedding, + pre_delete_collection=pre_delete_collection, + ) + + return store + + @classmethod + def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: + connection_string: str = get_from_dict_or_env( + data=kwargs, + key="connection_string", + env_key="POSTGRES_CONNECTION_STRING", + ) + + if not connection_string: + raise ValueError( + "Postgres connection string is required" + "Either pass it as a parameter" + "or set the POSTGRES_CONNECTION_STRING environment variable." + ) + + return connection_string + + @classmethod + def from_documents( + cls: Type[PGEmbedding], + documents: List[Document], + embedding: Embeddings, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGEmbedding: + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + connection_string = cls.get_connection_string(kwargs) + + kwargs["connection_string"] = connection_string + + return cls.from_texts( + texts=texts, + pre_delete_collection=pre_delete_collection, + embedding=embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + **kwargs, + )