forked from Archives/langchain
Add texts with embeddings to PGVector wrapper (#5500)
Similar to #1813 for faiss, this PR is to extend functionality to pass text and its vector pair to initialize and add embeddings to the PGVector wrapper. Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: - @dev2049
This commit is contained in:
parent
8d07ba0d51
commit
3bae595182
@ -192,6 +192,72 @@ class PGVector(VectorStore):
|
|||||||
def get_collection(self, session: Session) -> Optional["CollectionStore"]:
|
def get_collection(self, session: Session) -> Optional["CollectionStore"]:
|
||||||
return CollectionStore.get_by_name(session, self.collection_name)
|
return CollectionStore.get_by_name(session, self.collection_name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __from(
|
||||||
|
cls,
|
||||||
|
texts: List[str],
|
||||||
|
embeddings: List[List[float]],
|
||||||
|
embedding: Embeddings,
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
ids: Optional[List[str]] = None,
|
||||||
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
|
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
||||||
|
pre_delete_collection: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> PGVector:
|
||||||
|
if ids is None:
|
||||||
|
ids = [str(uuid.uuid1()) for _ in texts]
|
||||||
|
|
||||||
|
if not metadatas:
|
||||||
|
metadatas = [{} for _ in texts]
|
||||||
|
|
||||||
|
connection_string = cls.get_connection_string(kwargs)
|
||||||
|
|
||||||
|
store = cls(
|
||||||
|
connection_string=connection_string,
|
||||||
|
collection_name=collection_name,
|
||||||
|
embedding_function=embedding,
|
||||||
|
distance_strategy=distance_strategy,
|
||||||
|
pre_delete_collection=pre_delete_collection,
|
||||||
|
)
|
||||||
|
|
||||||
|
store.add_embeddings(
|
||||||
|
texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return store
|
||||||
|
|
||||||
|
def add_embeddings(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
embeddings: List[List[float]],
|
||||||
|
metadatas: List[dict],
|
||||||
|
ids: List[str],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Add embeddings to the vectorstore.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Iterable of strings to add to the vectorstore.
|
||||||
|
embeddings: List of list of embedding vectors.
|
||||||
|
metadatas: List of metadatas associated with the texts.
|
||||||
|
kwargs: vectorstore specific parameters
|
||||||
|
"""
|
||||||
|
with Session(self._conn) as session:
|
||||||
|
collection = self.get_collection(session)
|
||||||
|
if not collection:
|
||||||
|
raise ValueError("Collection not found")
|
||||||
|
for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
|
||||||
|
embedding_store = EmbeddingStore(
|
||||||
|
embedding=embedding,
|
||||||
|
document=text,
|
||||||
|
cmetadata=metadata,
|
||||||
|
custom_id=id,
|
||||||
|
)
|
||||||
|
collection.embeddings.append(embedding_store)
|
||||||
|
session.add(embedding_store)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -380,19 +446,64 @@ class PGVector(VectorStore):
|
|||||||
"Either pass it as a parameter
|
"Either pass it as a parameter
|
||||||
or set the PGVECTOR_CONNECTION_STRING environment variable.
|
or set the PGVECTOR_CONNECTION_STRING environment variable.
|
||||||
"""
|
"""
|
||||||
|
embeddings = embedding.embed_documents(list(texts))
|
||||||
|
|
||||||
connection_string = cls.get_connection_string(kwargs)
|
return cls.__from(
|
||||||
|
texts,
|
||||||
store = cls(
|
embeddings,
|
||||||
connection_string=connection_string,
|
embedding,
|
||||||
|
metadatas=metadatas,
|
||||||
|
ids=ids,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
embedding_function=embedding,
|
|
||||||
distance_strategy=distance_strategy,
|
distance_strategy=distance_strategy,
|
||||||
pre_delete_collection=pre_delete_collection,
|
pre_delete_collection=pre_delete_collection,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
store.add_texts(texts=texts, metadatas=metadatas, ids=ids, **kwargs)
|
@classmethod
|
||||||
return store
|
def from_embeddings(
|
||||||
|
cls,
|
||||||
|
text_embeddings: List[Tuple[str, List[float]]],
|
||||||
|
embedding: Embeddings,
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||||
|
distance_strategy: DistanceStrategy = DistanceStrategy.COSINE,
|
||||||
|
ids: Optional[List[str]] = None,
|
||||||
|
pre_delete_collection: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> PGVector:
|
||||||
|
"""Construct PGVector wrapper from raw documents and pre-
|
||||||
|
generated embeddings.
|
||||||
|
|
||||||
|
Return VectorStore initialized from documents and embeddings.
|
||||||
|
Postgres connection string is required
|
||||||
|
"Either pass it as a parameter
|
||||||
|
or set the PGVECTOR_CONNECTION_STRING environment variable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain import PGVector
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
embeddings = OpenAIEmbeddings()
|
||||||
|
text_embeddings = embeddings.embed_documents(texts)
|
||||||
|
text_embedding_pairs = list(zip(texts, text_embeddings))
|
||||||
|
faiss = PGVector.from_embeddings(text_embedding_pairs, embeddings)
|
||||||
|
"""
|
||||||
|
texts = [t[0] for t in text_embeddings]
|
||||||
|
embeddings = [t[1] for t in text_embeddings]
|
||||||
|
|
||||||
|
return cls.__from(
|
||||||
|
texts,
|
||||||
|
embeddings,
|
||||||
|
embedding,
|
||||||
|
metadatas=metadatas,
|
||||||
|
ids=ids,
|
||||||
|
collection_name=collection_name,
|
||||||
|
distance_strategy=distance_strategy,
|
||||||
|
pre_delete_collection=pre_delete_collection,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
|
def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
|
||||||
|
@ -49,6 +49,22 @@ def test_pgvector() -> None:
|
|||||||
assert output == [Document(page_content="foo")]
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pgvector_embeddings() -> None:
|
||||||
|
"""Test end to end construction with embeddings and search."""
|
||||||
|
texts = ["foo", "bar", "baz"]
|
||||||
|
text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts)
|
||||||
|
text_embedding_pairs = list(zip(texts, text_embeddings))
|
||||||
|
docsearch = PGVector.from_embeddings(
|
||||||
|
text_embeddings=text_embedding_pairs,
|
||||||
|
collection_name="test_collection",
|
||||||
|
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||||
|
connection_string=CONNECTION_STRING,
|
||||||
|
pre_delete_collection=True,
|
||||||
|
)
|
||||||
|
output = docsearch.similarity_search("foo", k=1)
|
||||||
|
assert output == [Document(page_content="foo")]
|
||||||
|
|
||||||
|
|
||||||
def test_pgvector_with_metadatas() -> None:
|
def test_pgvector_with_metadatas() -> None:
|
||||||
"""Test end to end construction and search."""
|
"""Test end to end construction and search."""
|
||||||
texts = ["foo", "bar", "baz"]
|
texts = ["foo", "bar", "baz"]
|
||||||
|
Loading…
Reference in New Issue
Block a user