community[minor]: Add ApertureDB as a vectorstore (#24088)

Thank you for contributing to LangChain!

- [X] *ApertureDB as vectorstore**: "community: Add ApertureDB as a
vectorestore"

- **Description:** this change provides a new community integration that
uses ApertureData's ApertureDB as a vector store.
    - **Issue:** none
    - **Dependencies:** depends on ApertureDB Python SDK
    - **Twitter handle:** ApertureData

- [X] **Add tests and docs**: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

Integration tests rely on a local run of a public docker image.
Example notebook additionally relies on a local Ollama server.

- [X] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

All lint tests pass.

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: Gautam <gautam@aperturedata.io>
pull/24313/head
bovlb 2 months ago committed by GitHub
parent c59e663365
commit 5caa381177
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -5,10 +5,10 @@ services:
dockerfile: libs/langchain/dev.Dockerfile
context: ..
volumes:
# Update this to wherever you want VS Code to mount the folder of your project
# Update this to wherever you want VS Code to mount the folder of your project
- ..:/workspaces/langchain:cached
networks:
- langchain-network
- langchain-network
# environment:
# MONGO_ROOT_USERNAME: root
# MONGO_ROOT_PASSWORD: example123
@ -28,5 +28,3 @@ services:
networks:
langchain-network:
driver: bridge

@ -0,0 +1,310 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "683953b3",
"metadata": {},
"source": [
"# ApertureDB\n",
"\n",
"[ApertureDB](https://docs.aperturedata.io) is a database that stores, indexes, and manages multi-modal data like text, images, videos, bounding boxes, and embeddings, together with their associated metadata.\n",
"\n",
"This notebook explains how to use the embeddings functionality of ApertureDB."
]
},
{
"cell_type": "markdown",
"id": "e7393beb",
"metadata": {},
"source": [
"## Install ApertureDB Python SDK\n",
"\n",
"This installs the [Python SDK](https://docs.aperturedata.io/category/aperturedb-python-sdk) used to write client code for ApertureDB."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a62cff8a-bcf7-4e33-bbbc-76999c2e3e20",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade --quiet aperturedb"
]
},
{
"cell_type": "markdown",
"id": "4fe12f77",
"metadata": {},
"source": [
"## Run an ApertureDB instance\n",
"\n",
"To continue, you should have an [ApertureDB instance up and running](https://docs.aperturedata.io/HowToGuides/start/Setup) and configure your environment to use it. \n",
"There are various ways to do that, for example:\n",
"\n",
"```bash\n",
"docker run --publish 55555:55555 aperturedata/aperturedb-standalone\n",
"adb config create local --active --no-interactive\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "667eabca",
"metadata": {},
"source": [
"## Download some web documents\n",
"We're going to do a mini-crawl here of one web page."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0798dfdb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"USER_AGENT environment variable not set, consider setting it to identify your requests.\n"
]
}
],
"source": [
"# For loading documents from web\n",
"from langchain_community.document_loaders import WebBaseLoader\n",
"\n",
"loader = WebBaseLoader(\"https://docs.aperturedata.io\")\n",
"docs = loader.load()"
]
},
{
"cell_type": "markdown",
"id": "5f077d11",
"metadata": {},
"source": [
"## Select embeddings model\n",
"\n",
"We want to use OllamaEmbeddings so we have to import the necessary modules.\n",
"\n",
"Ollama can be set up as a docker container as described in the [documentation](https://hub.docker.com/r/ollama/ollama), for example:\n",
"```bash\n",
"# Run server\n",
"docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama\n",
"# Tell server to load a specific model\n",
"docker exec ollama ollama run llama2\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8b6ed9cd-81b9-46e5-9c20-5aafca2844d0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain_community.embeddings import OllamaEmbeddings\n",
"\n",
"embeddings = OllamaEmbeddings()"
]
},
{
"cell_type": "markdown",
"id": "b7b313e6",
"metadata": {},
"source": [
"## Split documents into segments\n",
"\n",
"We want to turn our single document into multiple segments."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3c4b7b31",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"text_splitter = RecursiveCharacterTextSplitter()\n",
"documents = text_splitter.split_documents(docs)"
]
},
{
"cell_type": "markdown",
"id": "46339d32",
"metadata": {},
"source": [
"## Create vectorstore from documents and embeddings\n",
"\n",
"This code creates a vectorstore in the ApertureDB instance.\n",
"Within the instance, this vectorstore is represented as a \"[descriptor set](https://docs.aperturedata.io/category/descriptorset-commands)\".\n",
"By default, the descriptor set is named `langchain`. The following code will generate embeddings for each document and store them in ApertureDB as descriptors. This will take a few seconds as the embeddings are bring generated."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dcf88bdf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain_community.vectorstores import ApertureDB\n",
"\n",
"vector_db = ApertureDB.from_documents(documents, embeddings)"
]
},
{
"cell_type": "markdown",
"id": "7672877b",
"metadata": {},
"source": [
"## Select a large language model\n",
"\n",
"Again, we use the Ollama server we set up for local processing."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9a005e4b",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import Ollama\n",
"\n",
"llm = Ollama(model=\"llama2\")"
]
},
{
"cell_type": "markdown",
"id": "cd54f2ad",
"metadata": {},
"source": [
"## Build a RAG chain\n",
"\n",
"Now we have all the components we need to create a RAG (Retrieval-Augmented Generation) chain. This chain does the following:\n",
"1. Generate embedding descriptor for user query\n",
"2. Find text segments that are similar to the user query using the vector store\n",
"3. Pass user query and context documents to the LLM using a prompt template\n",
"4. Return the LLM's answer"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a8c513ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the provided context, ApertureDB can store images. In fact, it is specifically designed to manage multimodal data such as images, videos, documents, embeddings, and associated metadata including annotations. So, ApertureDB has the capability to store and manage images.\n"
]
}
],
"source": [
"# Create prompt\n",
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_template(\"\"\"Answer the following question based only on the provided context:\n",
"\n",
"<context>\n",
"{context}\n",
"</context>\n",
"\n",
"Question: {input}\"\"\")\n",
"\n",
"\n",
"# Create a chain that passes documents to an LLM\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"\n",
"document_chain = create_stuff_documents_chain(llm, prompt)\n",
"\n",
"\n",
"# Treat the vectorstore as a document retriever\n",
"retriever = vector_db.as_retriever()\n",
"\n",
"\n",
"# Create a RAG chain that connects the retriever to the LLM\n",
"from langchain.chains import create_retrieval_chain\n",
"\n",
"retrieval_chain = create_retrieval_chain(retriever, document_chain)"
]
},
{
"cell_type": "markdown",
"id": "3bc6a882",
"metadata": {},
"source": [
"## Run the RAG chain\n",
"\n",
"Finally we pass a question to the chain and get our answer. This will take a few seconds to run as the LLM generates an answer from the query and context documents."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "020f29f1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Based on the provided context, ApertureDB can store images in several ways:\n",
"\n",
"1. Multimodal data management: ApertureDB offers a unified interface to manage multimodal data such as images, videos, documents, embeddings, and associated metadata including annotations. This means that images can be stored along with other types of data in a single database instance.\n",
"2. Image storage: ApertureDB provides image storage capabilities through its integration with the public cloud providers or on-premise installations. This allows customers to host their own ApertureDB instances and store images on their preferred cloud provider or on-premise infrastructure.\n",
"3. Vector database: ApertureDB also offers a vector database that enables efficient similarity search and classification of images based on their semantic meaning. This can be useful for applications where image search and classification are important, such as in computer vision or machine learning workflows.\n",
"\n",
"Overall, ApertureDB provides flexible and scalable storage options for images, allowing customers to choose the deployment model that best suits their needs.\n"
]
}
],
"source": [
"user_query = \"How can ApertureDB store images?\"\n",
"response = retrieval_chain.invoke({\"input\": user_query})\n",
"print(response[\"answer\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -43,6 +43,9 @@ if TYPE_CHECKING:
from langchain_community.vectorstores.apache_doris import (
ApacheDoris,
)
from langchain_community.vectorstores.aperturedb import (
ApertureDB,
)
from langchain_community.vectorstores.astradb import (
AstraDB,
)
@ -311,6 +314,7 @@ __all__ = [
"AnalyticDB",
"Annoy",
"ApacheDoris",
"ApertureDB",
"AstraDB",
"AtlasDB",
"AwaDB",
@ -413,6 +417,7 @@ _module_lookup = {
"AnalyticDB": "langchain_community.vectorstores.analyticdb",
"Annoy": "langchain_community.vectorstores.annoy",
"ApacheDoris": "langchain_community.vectorstores.apache_doris",
"ApertureDB": "langchain_community.vectorstores.aperturedb",
"AstraDB": "langchain_community.vectorstores.astradb",
"AtlasDB": "langchain_community.vectorstores.atlas",
"AwaDB": "langchain_community.vectorstores.awadb",

@ -0,0 +1,516 @@
# System imports
from __future__ import annotations
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
# Third-party imports
import numpy as np
# Local imports
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.indexing.base import UpsertResponse
from langchain_core.vectorstores import VectorStore
from typing_extensions import override
# Configure some defaults
ENGINE = "HNSW"
METRIC = "CS"
DESCRIPTOR_SET = "langchain"
BATCHSIZE = 1000
PROPERTY_PREFIX = "lc_" # Prefix for properties that are in the client metadata
TEXT_PROPERTY = "text" # Property name for the text
UNIQUEID_PROPERTY = "uniqueid" # Property name for the unique id
class ApertureDB(VectorStore):
@override
def __init__(
self,
embeddings: Embeddings,
descriptor_set: str = DESCRIPTOR_SET,
dimensions: Optional[int] = None,
engine: Optional[str] = None,
metric: Optional[str] = None,
log_level: int = logging.WARN,
properties: Optional[Dict] = None,
**kwargs: Any,
) -> None:
"""Create a vectorstore backed by ApertureDB
A single ApertureDB instance can support many vectorstores,
distinguished by 'descriptor_set' name. The descriptor set is created
if it does not exist. Different descriptor sets can use different
engines and metrics, be supplied by different embedding models, and have
different dimensions.
See ApertureDB documentation on `AddDescriptorSet`
https://docs.aperturedata.io/query_language/Reference/descriptor_commands/desc_set_commands/AddDescriptorSet
for more information on the engine and metric options.
Args:
embeddings (Embeddings): Embeddings object
descriptor_set (str, optional): Descriptor set name. Defaults to
"langchain".
dimensions (Optional[int], optional): Number of dimensions of the
embeddings. Defaults to None.
engine (str, optional): Engine to use. Defaults to "HNSW" for new
descriptorsets.
metric (str, optional): Metric to use. Defaults to "CS" for new
descriptorsets.
log_level (int, optional): Logging level. Defaults to logging.WARN.
"""
# ApertureDB imports
try:
from aperturedb.Utils import Utils, create_connector
except ImportError:
raise ImportError(
"ApertureDB is not installed. Please install it using "
"'pip install aperturedb'"
)
super().__init__(**kwargs)
self.logger = logging.getLogger(__name__)
self.logger.setLevel(log_level)
self.descriptor_set = descriptor_set
self.embedding_function = embeddings
self.dimensions = dimensions
self.engine = engine
self.metric = metric
self.properties = properties
if embeddings is None:
self.logger.fatal("No embedding function provided.")
raise ValueError("No embedding function provided.")
try:
from aperturedb.Utils import Utils, create_connector
except ImportError:
self.logger.exception(
"ApertureDB is not installed. Please install it using "
"'pip install aperturedb'"
)
raise
self.connection = create_connector()
self.utils = Utils(self.connection)
try:
self.utils.status()
except Exception:
self.logger.exception("Failed to connect to ApertureDB")
raise
self._find_or_add_descriptor_set()
def _find_or_add_descriptor_set(self) -> None:
descriptor_set = self.descriptor_set
"""Checks if the descriptor set exists, if not, creates it"""
find_ds_query = [
{
"FindDescriptorSet": {
"with_name": descriptor_set,
"engines": True,
"metrics": True,
"dimensions": True,
"results": {"all_properties": True},
}
}
]
r, b = self.connection.query(find_ds_query)
assert self.connection.last_query_ok(), r
n_entities = (
len(r[0]["FindDescriptorSet"]["entities"])
if "entities" in r[0]["FindDescriptorSet"]
else 0
)
assert n_entities <= 1, "Multiple descriptor sets with the same name"
if n_entities == 1: # Descriptor set exists already
e = r[0]["FindDescriptorSet"]["entities"][0]
self.logger.info(f"Descriptor set {descriptor_set} already exists")
engines = e["_engines"]
assert len(engines) == 1, "Only one engine is supported"
if self.engine is None:
self.engine = engines[0]
elif self.engine != engines[0]:
self.logger.error(f"Engine mismatch: {self.engine} != {engines[0]}")
metrics = e["_metrics"]
assert len(metrics) == 1, "Only one metric is supported"
if self.metric is None:
self.metric = metrics[0]
elif self.metric != metrics[0]:
self.logger.error(f"Metric mismatch: {self.metric} != {metrics[0]}")
dimensions = e["_dimensions"]
if self.dimensions is None:
self.dimensions = dimensions
elif self.dimensions != dimensions:
self.logger.error(
f"Dimensions mismatch: {self.dimensions} != {dimensions}"
)
self.properties = {
k[len(PROPERTY_PREFIX) :]: v
for k, v in e.items()
if k.startswith(PROPERTY_PREFIX)
}
else:
self.logger.info(
f"Descriptor set {descriptor_set} does not exist. Creating it"
)
if self.engine is None:
self.engine = ENGINE
if self.metric is None:
self.metric = METRIC
if self.dimensions is None:
self.dimensions = len(self.embedding_function.embed_query("test"))
properties = (
{PROPERTY_PREFIX + k: v for k, v in self.properties.items()}
if self.properties is not None
else None
)
self.utils.add_descriptorset(
name=descriptor_set,
dim=self.dimensions,
engine=self.engine,
metric=self.metric,
properties=properties,
)
# Create indexes
self.utils.create_entity_index("_Descriptor", "_create_txn")
self.utils.create_entity_index("_DescriptorSet", "_name")
self.utils.create_entity_index("_Descriptor", UNIQUEID_PROPERTY)
@override
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete documents from the vectorstore by id.
Args:
ids: List of ids to delete from the vectorstore.
Returns:
True if the deletion was successful, False otherwise
"""
assert ids is not None, "ids must be provided"
query = [
{
"DeleteDescriptor": {
"set": self.descriptor_set,
"constraints": {UNIQUEID_PROPERTY: ["in", ids]},
}
}
]
result, _ = self.utils.execute(query)
return result
@override
def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
"""Find documents in the vectorstore by id.
Args:
ids: List of ids to find in the vectorstore.
Returns:
documents: List of Document objects found in the vectorstore.
"""
query = [
{
"FindDescriptor": {
"set": self.descriptor_set,
"constraints": {UNIQUEID_PROPERTY: ["in", ids]},
"results": {"all_properties": True},
}
}
]
results, _ = self.utils.execute(query)
docs = [
self._descriptor_to_document(d)
for d in results[0]["FindDescriptor"].get("entities", [])
]
return docs
@override
def similarity_search(
self, query: str, k: int = 4, *args: Any, **kwargs: Any
) -> List[Document]:
"""Search for documents similar to the query using the vectorstore
Args:
query: Query string to search for.
k: Number of results to return.
Returns:
List of Document objects ordered by decreasing similarity to the query.
"""
assert self.embedding_function is not None, "Embedding function is not set"
embedding = self.embedding_function.embed_query(query)
return self.similarity_search_by_vector(embedding, k, *args, **kwargs)
@override
def similarity_search_with_score(
self, query: str, *args: Any, **kwargs: Any
) -> List[Tuple[Document, float]]:
embedding = self.embedding_function.embed_query(query)
return self._similarity_search_with_score_by_vector(embedding, *args, **kwargs)
def _descriptor_to_document(self, d: dict) -> Document:
metadata = {}
for k, v in d.items():
if k.startswith(PROPERTY_PREFIX):
metadata[k[len(PROPERTY_PREFIX) :]] = v
text = d[TEXT_PROPERTY]
uniqueid = d[UNIQUEID_PROPERTY]
doc = Document(page_content=text, metadata=metadata, id=uniqueid)
return doc
def _similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4, vectors: bool = False
) -> List[Tuple[Document, float]]:
from aperturedb.Descriptors import Descriptors
descriptors = Descriptors(self.connection)
start_time = time.time()
descriptors.find_similar(
set=self.descriptor_set, vector=embedding, k_neighbors=k, distances=True
)
self.logger.info(
f"ApertureDB similarity search took {time.time() - start_time} seconds"
)
return [(self._descriptor_to_document(d), d["_distance"]) for d in descriptors]
@override
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
"""Returns the k most similar documents to the given embedding vector
Args:
embedding: The embedding vector to search for
k: The number of similar documents to return
Returns:
List of Document objects ordered by decreasing similarity to the query.
"""
from aperturedb.Descriptors import Descriptors
descriptors = Descriptors(self.connection)
start_time = time.time()
descriptors.find_similar(
set=self.descriptor_set, vector=embedding, k_neighbors=k
)
self.logger.info(
f"ApertureDB similarity search took {time.time() - start_time} seconds"
)
return [self._descriptor_to_document(d) for d in descriptors]
@override
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Returns similar documents to the query that also have diversity
This algorithm balances relevance and diversity in the search results.
Args:
query: Query string to search for.
k: Number of results to return.
fetch_k: Number of results to fetch.
lambda_mult: Lambda multiplier for MMR.
Returns:
List of Document objects ordered by decreasing similarity/diversty.
"""
self.logger.info(f"Max Marginal Relevance search for query: {query}")
embedding = self.embedding_function.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult, **kwargs
)
@override
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Returns similar documents to the vector that also have diversity
This algorithm balances relevance and diversity in the search results.
Args:
embedding: Embedding vector to search for.
k: Number of results to return.
fetch_k: Number of results to fetch.
lambda_mult: Lambda multiplier for MMR.
Returns:
List of Document objects ordered by decreasing similarity/diversty.
"""
from aperturedb.Descriptors import Descriptors
descriptors = Descriptors(self.connection)
start_time = time.time()
descriptors.find_similar_mmr(
set=self.descriptor_set,
vector=embedding,
k_neighbors=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
)
self.logger.info(
f"ApertureDB similarity search mmr took {time.time() - start_time} seconds"
)
return [self._descriptor_to_document(d) for d in descriptors]
@classmethod
@override
def from_texts(
cls: Type[ApertureDB],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> ApertureDB:
"""Creates a new vectorstore from a list of texts
Args:
texts: List of text strings
embedding: Embeddings object as for constructing the vectorstore
metadatas: Optional list of metadatas associated with the texts.
**kwargs: Additional arguments to pass to the constructor
"""
store = cls(embeddings=embedding, **kwargs)
store.add_texts(texts, metadatas)
return store
@classmethod
@override
def from_documents(
cls: Type[ApertureDB],
documents: List[Document],
embedding: Embeddings,
**kwargs: Any,
) -> ApertureDB:
"""Creates a new vectorstore from a list of documents
Args:
documents: List of Document objects
embedding: Embeddings object as for constructing the vectorstore
metadatas: Optional list of metadatas associated with the texts.
**kwargs: Additional arguments to pass to the constructor
"""
store = cls(embeddings=embedding, **kwargs)
store.add_documents(documents)
return store
@classmethod
def delete_vectorstore(class_, descriptor_set: str) -> None:
"""Deletes a vectorstore and all its data from the database
Args:
descriptor_set: The name of the descriptor set to delete
"""
from aperturedb.Utils import Utils, create_connector
db = create_connector()
utils = Utils(db)
utils.remove_descriptorset(descriptor_set)
@classmethod
def list_vectorstores(class_) -> None:
"""Returns a list of all vectorstores in the database
Returns:
List of descriptor sets with properties
"""
from aperturedb.Utils import create_connector
db = create_connector()
query = [
{
"FindDescriptorSet": {
# Return all properties
"results": {"all_properties": True},
"engines": True,
"metrics": True,
"dimensions": True,
}
}
]
response, _ = db.query(query)
assert db.last_query_ok(), response
return response[0]["FindDescriptorSet"]["entities"]
@override
def upsert(self, items: Sequence[Document], /, **kwargs: Any) -> UpsertResponse:
"""Insert or update items
Updating documents is dependent on the documents' `id` attribute.
Args:
items: List of Document objects to upsert
Returns:
UpsertResponse object with succeeded and failed
"""
# For now, simply delete and add
# We could do something more efficient to update metadata,
# but we don't support changing the embedding of a descriptor.
from aperturedb.ParallelLoader import ParallelLoader
ids_to_delete: List[str] = [
item.id for item in items if hasattr(item, "id") and item.id is not None
]
if ids_to_delete:
self.delete(ids_to_delete)
texts = [doc.page_content for doc in items]
metadatas = [
doc.metadata if getattr(doc, "metadata", None) is not None else {}
for doc in items
]
embeddings = self.embedding_function.embed_documents(texts)
ids: List[str] = [
doc.id if hasattr(doc, "id") and doc.id is not None else str(uuid.uuid4())
for doc in items
]
data = []
for text, embedding, metadata, unique_id in zip(
texts, embeddings, metadatas, ids
):
properties = {PROPERTY_PREFIX + k: v for k, v in metadata.items()}
properties[TEXT_PROPERTY] = text
properties[UNIQUEID_PROPERTY] = unique_id
command = {
"AddDescriptor": {
"set": self.descriptor_set,
"properties": properties,
}
}
query = [command]
blobs = [np.array(embedding, dtype=np.float32).tobytes()]
data.append((query, blobs))
loader = ParallelLoader(self.connection)
loader.ingest(data, batchsize=BATCHSIZE)
return UpsertResponse(succeeded=ids, failed=[])

@ -0,0 +1,7 @@
services:
aperturedb:
image: aperturedata/aperturedb-standalone:latest
restart: on-failure:0
container_name: aperturedb
ports:
- 55555:55555

@ -0,0 +1,29 @@
"""Test ApertureDB functionality."""
import uuid
import pytest
from langchain_standard_tests.integration_tests.vectorstores import (
AsyncReadWriteTestSuite,
ReadWriteTestSuite,
)
from langchain_community.vectorstores import ApertureDB
class TestApertureDBReadWriteTestSuite(ReadWriteTestSuite):
@pytest.fixture
def vectorstore(self) -> ApertureDB:
descriptor_set = uuid.uuid4().hex # Fresh descriptor set for each test
return ApertureDB(
embeddings=self.get_embeddings(), descriptor_set=descriptor_set
)
class TestAsyncApertureDBReadWriteTestSuite(AsyncReadWriteTestSuite):
@pytest.fixture
async def vectorstore(self) -> ApertureDB:
descriptor_set = uuid.uuid4().hex # Fresh descriptor set for each test
return ApertureDB(
embeddings=self.get_embeddings(), descriptor_set=descriptor_set
)

@ -10,6 +10,7 @@ EXPECTED_ALL = [
"AnalyticDB",
"Annoy",
"ApacheDoris",
"ApertureDB",
"AstraDB",
"AtlasDB",
"AwaDB",

@ -48,6 +48,7 @@ def test_compatible_vectorstore_documentation() -> None:
documented = {
"Aerospike",
"AnalyticDB",
"ApertureDB",
"AstraDB",
"AzureCosmosDBVectorSearch",
"AzureCosmosDBNoSqlVectorSearch",

Loading…
Cancel
Save