astradb: move to langchain-datastax repo (#18354)

Erick Friis 3 months ago committed by GitHub
parent b641be2edf
commit 6afb135baa

@@ -49,9 +49,11 @@ if __name__ == "__main__":
dirs_to_run["extended-test"].add(dir_)
elif file.startswith("libs/partners"):
partner_dir = file.split("/")[2]
if os.path.isdir(f"libs/partners/{partner_dir}") and os.listdir(
f"libs/partners/{partner_dir}"
) != ["README.md"]:
if os.path.isdir(f"libs/partners/{partner_dir}") and [
filename
for filename in os.listdir(f"libs/partners/{partner_dir}")
if not filename.startswith(".")
] != ["README.md"]:
dirs_to_run["test"].add(f"libs/partners/{partner_dir}")
# Skip if the directory was deleted or is just a tombstone readme
elif file.startswith("libs/"):

@@ -20,11 +20,19 @@ jobs:
with:
repository: langchain-ai/langchain-google
path: langchain-google
- uses: actions/checkout@v4
with:
repository: langchain-ai/langchain-datastax
path: langchain-datastax
- name: Move google libs
run: |
rm -rf langchain/libs/partners/google-genai langchain/libs/partners/google-vertexai
rm -rf \
langchain/libs/partners/google-genai \
langchain/libs/partners/google-vertexai \
langchain/libs/partners/astradb
mv langchain-google/libs/genai langchain/libs/partners/google-genai
mv langchain-google/libs/vertexai langchain/libs/partners/google-vertexai
mv langchain-datastax/libs/astradb langchain/libs/partners/astradb
- name: Set Git config
working-directory: langchain

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@@ -1,66 +0,0 @@
SHELL := /bin/bash
.PHONY: all format lint test tests integration_test integration_tests spell_check help
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
INTEGRATION_TEST_FILE ?= tests/integration_tests/
test:
poetry run pytest $(TEST_FILE)
tests:
poetry run pytest $(TEST_FILE)
integration_test:
poetry run pytest $(INTEGRATION_TEST_FILE)
integration_tests:
poetry run pytest $(INTEGRATION_TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/astradb --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_astradb
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_astradb -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

@@ -1,68 +1,3 @@
# langchain-astradb
This package has moved!
This package contains the LangChain integrations for using DataStax Astra DB.
> DataStax [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) is a serverless vector-capable database built on Apache Cassandra® and made conveniently available
> through an easy-to-use JSON API.
_**Note.** For a short transitional period, only some of the Astra DB integration classes are contained in this package (the rest remain in `langchain-community`). Soon, and certainly by LangChain version 0.2, all of the Astra DB support will be moved out of `langchain-community` and into this package._
## Installation and Setup
Installation of this partner package:
```bash
pip install langchain-astradb
```
## Integrations overview
### Vector Store
```python
from langchain_astradb import AstraDBVectorStore
my_store = AstraDBVectorStore(
embedding=my_embeddings,
collection_name="my_store",
api_endpoint="https://...",
token="AstraCS:...",
)
```
### Chat message history
```python
from langchain_astradb import AstraDBChatMessageHistory
message_history = AstraDBChatMessageHistory(
session_id="test-session",
api_endpoint="...",
token="...",
)
```
### Store
```python
from langchain_astradb import AstraDBStore
store = AstraDBStore(
collection_name="my_kv_store",
api_endpoint="...",
token="..."
)
```
### Byte Store
```python
from langchain_astradb import AstraDBByteStore
store = AstraDBByteStore(
collection_name="my_kv_store",
api_endpoint="...",
token="..."
)
```
## Reference
See the [LangChain docs page](https://python.langchain.com/docs/integrations/providers/astradb) for a more detailed listing.
https://github.com/langchain-ai/langchain-datastax/tree/main/libs/astradb

@@ -1,10 +0,0 @@
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.vectorstores import AstraDBVectorStore
__all__ = [
"AstraDBByteStore",
"AstraDBStore",
"AstraDBChatMessageHistory",
"AstraDBVectorStore",
]

@@ -1,148 +0,0 @@
"""Astra DB - based chat message history, based on astrapy."""
from __future__ import annotations
import json
import time
from typing import List, Optional, Sequence
from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
DEFAULT_COLLECTION_NAME = "langchain_message_store"
class AstraDBChatMessageHistory(BaseChatMessageHistory):
def __init__(
self,
*,
session_id: str,
collection_name: str = DEFAULT_COLLECTION_NAME,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
) -> None:
"""Chat message history that stores history in Astra DB.
Args:
session_id: arbitrary key that is used to store the messages
of a single chat session.
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
"""
self.astra_env = _AstraDBCollectionEnvironment(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
self.session_id = session_id
self.collection_name = collection_name
@property
def messages(self) -> List[BaseMessage]:
"""Retrieve all session messages from DB"""
self.astra_env.ensure_db_setup()
message_blobs = [
doc["body_blob"]
for doc in sorted(
self.collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
),
key=lambda _doc: _doc["timestamp"],
)
]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages
@messages.setter
def messages(self, messages: List[BaseMessage]) -> None:
raise NotImplementedError("Use add_messages instead")
async def aget_messages(self) -> List[BaseMessage]:
await self.astra_env.aensure_db_setup()
docs = self.async_collection.paginated_find(
filter={
"session_id": self.session_id,
},
projection={
"timestamp": 1,
"body_blob": 1,
},
)
sorted_docs = sorted(
[doc async for doc in docs],
key=lambda _doc: _doc["timestamp"],
)
message_blobs = [doc["body_blob"] for doc in sorted_docs]
items = [json.loads(message_blob) for message_blob in message_blobs]
messages = messages_from_dict(items)
return messages
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
self.astra_env.ensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
self.collection.chunked_insert_many(docs)
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
await self.astra_env.aensure_db_setup()
docs = [
{
"timestamp": time.time(),
"session_id": self.session_id,
"body_blob": json.dumps(message_to_dict(message)),
}
for message in messages
]
await self.async_collection.chunked_insert_many(docs)
def clear(self) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"session_id": self.session_id})
async def aclear(self) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"session_id": self.session_id})

@@ -1,217 +0,0 @@
from __future__ import annotations
import base64
from abc import ABC, abstractmethod
from typing import (
Any,
AsyncIterator,
Generic,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from astrapy.db import AstraDB, AsyncAstraDB
from langchain_core.stores import BaseStore, ByteStore
from langchain_astradb.utils.astradb import (
SetupMode,
_AstraDBCollectionEnvironment,
)
V = TypeVar("V")
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
"""Base class for the DataStax AstraDB data store."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
self.collection = self.astra_env.collection
self.async_collection = self.astra_env.async_collection
@abstractmethod
def decode_value(self, value: Any) -> Optional[V]:
"""Decodes value from Astra DB"""
@abstractmethod
def encode_value(self, value: Optional[V]) -> Any:
"""Encodes value for Astra DB"""
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
self.astra_env.ensure_db_setup()
docs_dict = {}
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
await self.astra_env.aensure_db_setup()
docs_dict = {}
async for doc in self.async_collection.paginated_find(
filter={"_id": {"$in": list(keys)}}
):
docs_dict[doc["_id"]] = doc.get("value")
return [self.decode_value(docs_dict.get(key)) for key in keys]
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
self.astra_env.ensure_db_setup()
for k, v in key_value_pairs:
self.collection.upsert_one({"_id": k, "value": self.encode_value(v)})
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
await self.astra_env.aensure_db_setup()
for k, v in key_value_pairs:
await self.async_collection.upsert_one(
{"_id": k, "value": self.encode_value(v)}
)
def mdelete(self, keys: Sequence[str]) -> None:
self.astra_env.ensure_db_setup()
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
async def amdelete(self, keys: Sequence[str]) -> None:
await self.astra_env.aensure_db_setup()
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
self.astra_env.ensure_db_setup()
docs = self.collection.paginated_find()
for doc in docs:
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
await self.astra_env.aensure_db_setup()
async for doc in self.async_collection.paginated_find():
key = doc["_id"]
if not prefix or key.startswith(prefix):
yield key
class AstraDBStore(AstraDBBaseStore[Any]):
def __init__(
self,
collection_name: str,
*,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""BaseStore implementation using DataStax AstraDB as the underlying store.
The value type can be any type serializable by json.dumps.
It can be used to store embeddings with CacheBackedEmbeddings.
Documents in the AstraDB collection will have the format
.. code-block:: json
{
"_id": "<key>",
"value": <value>
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
def decode_value(self, value: Any) -> Any:
return value
def encode_value(self, value: Any) -> Any:
return value
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
def __init__(
self,
*,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
namespace: Optional[str] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
pre_delete_collection: bool = False,
setup_mode: SetupMode = SetupMode.SYNC,
) -> None:
"""ByteStore implementation using DataStax AstraDB as the underlying store.
The bytes values are converted to base64-encoded strings.
Documents in the AstraDB collection will have the format
.. code-block:: json
{
"_id": "<key>",
"value": "<byte64 string value>"
}
Args:
collection_name: name of the Astra DB collection to create/use.
token: API token for Astra DB usage.
api_endpoint: full URL to the API endpoint,
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AstraDB' instance.
async_astra_db_client: *alternative to token+api_endpoint*,
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
namespace: namespace (aka keyspace) where the
collection is created. Defaults to the database's "default namespace".
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
OFF).
pre_delete_collection: whether to delete the collection
before creating it. If False and the collection already exists,
the collection will be used as is.
"""
super().__init__(
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
astra_db_client=astra_db_client,
async_astra_db_client=async_astra_db_client,
namespace=namespace,
setup_mode=setup_mode,
pre_delete_collection=pre_delete_collection,
)
def decode_value(self, value: Any) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value)
def encode_value(self, value: Optional[bytes]) -> Any:
if value is None:
return None
return base64.b64encode(value).decode("ascii")

@@ -1,152 +0,0 @@
from __future__ import annotations
import asyncio
import inspect
from asyncio import InvalidStateError, Task
from enum import Enum
from typing import Awaitable, Optional, Union
import langchain_core
from astrapy.db import AstraDB, AsyncAstraDB
class SetupMode(Enum):
SYNC = 1
ASYNC = 2
OFF = 3
class _AstraDBEnvironment:
def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
) -> None:
self.token = token
self.api_endpoint = api_endpoint
astra_db = astra_db_client
async_astra_db = async_astra_db_client
self.namespace = namespace
# Conflicting-arg checks:
if astra_db_client is not None or async_astra_db_client is not None:
if token is not None or api_endpoint is not None:
raise ValueError(
"You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
"AstraDBEnvironment if passing 'token' and 'api_endpoint'."
)
if token and api_endpoint:
astra_db = AstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)
async_astra_db = AsyncAstraDB(
token=token,
api_endpoint=api_endpoint,
namespace=self.namespace,
)
if astra_db:
self.astra_db = astra_db.copy()
if async_astra_db:
self.async_astra_db = async_astra_db.copy()
else:
self.async_astra_db = self.astra_db.to_async()
elif async_astra_db:
self.async_astra_db = async_astra_db.copy()
self.astra_db = self.async_astra_db.to_sync()
else:
raise ValueError(
"Must provide 'astra_db_client' or 'async_astra_db_client' or "
"'token' and 'api_endpoint'"
)
self.astra_db.set_caller(
caller_name="langchain",
caller_version=getattr(langchain_core, "__version__", None),
)
self.async_astra_db.set_caller(
caller_name="langchain",
caller_version=getattr(langchain_core, "__version__", None),
)
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
def __init__(
self,
collection_name: str,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
astra_db_client: Optional[AstraDB] = None,
async_astra_db_client: Optional[AsyncAstraDB] = None,
namespace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
pre_delete_collection: bool = False,
embedding_dimension: Union[int, Awaitable[int], None] = None,
metric: Optional[str] = None,
) -> None:
from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
super().__init__(
token, api_endpoint, astra_db_client, async_astra_db_client, namespace
)
self.collection_name = collection_name
self.collection = AstraDBCollection(
collection_name=collection_name,
astra_db=self.astra_db,
)
self.async_collection = AsyncAstraDBCollection(
collection_name=collection_name,
astra_db=self.async_astra_db,
)
self.async_setup_db_task: Optional[Task] = None
if setup_mode == SetupMode.ASYNC:
async_astra_db = self.async_astra_db
async def _setup_db() -> None:
if pre_delete_collection:
await async_astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
dimension = await embedding_dimension
else:
dimension = embedding_dimension
await async_astra_db.create_collection(
collection_name, dimension=dimension, metric=metric
)
self.async_setup_db_task = asyncio.create_task(_setup_db())
elif setup_mode == SetupMode.SYNC:
if pre_delete_collection:
self.astra_db.delete_collection(collection_name)
if inspect.isawaitable(embedding_dimension):
raise ValueError(
"Cannot use an awaitable embedding_dimension with async_setup "
"set to False"
)
self.astra_db.create_collection(
collection_name,
dimension=embedding_dimension, # type: ignore[arg-type]
metric=metric,
)
def ensure_db_setup(self) -> None:
if self.async_setup_db_task:
try:
self.async_setup_db_task.result()
except InvalidStateError:
raise ValueError(
"Asynchronous setup of the DB not finished. "
"NB: AstraDB components sync methods shouldn't be called from the "
"event loop. Consider using their async equivalents."
)
async def aensure_db_setup(self) -> None:
if self.async_setup_db_task:
await self.async_setup_db_task

@@ -1,87 +0,0 @@
"""
Tools for the Maximal Marginal Relevance (MMR) reranking.
Duplicated from langchain_community to avoid cross-dependencies.
Functions "maximal_marginal_relevance" and "cosine_similarity"
are duplicated in this utility respectively from modules:
- "libs/community/langchain_community/vectorstores/utils.py"
- "libs/community/langchain_community/utils/math.py"
"""
import logging
from typing import List, Union
import numpy as np
logger = logging.getLogger(__name__)
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"""Row-wise cosine similarity between two equal-width matrices."""
if len(X) == 0 or len(Y) == 0:
return np.array([])
X = np.array(X)
Y = np.array(Y)
if X.shape[1] != Y.shape[1]:
raise ValueError(
f"Number of columns in X and Y must be the same. X has shape {X.shape} "
f"and Y has shape {Y.shape}."
)
try:
import simsimd as simd # type: ignore
X = np.array(X, dtype=np.float32)
Y = np.array(Y, dtype=np.float32)
Z = 1 - simd.cdist(X, Y, metric="cosine")
if isinstance(Z, float):
return np.array([Z])
return Z
except ImportError:
logger.info(
"Unable to import simsimd, defaulting to NumPy implementation. If you want "
"to use simsimd please install with `pip install simsimd`."
)
X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1)
# Ignore divide-by-zero and invalid-value runtime warnings, as those are handled below.
with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity
def maximal_marginal_relevance(
query_embedding: np.ndarray,
embedding_list: list,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[int]:
"""Calculate maximal marginal relevance."""
if min(k, len(embedding_list)) <= 0:
return []
if query_embedding.ndim == 1:
query_embedding = np.expand_dims(query_embedding, axis=0)
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
most_similar = int(np.argmax(similarity_to_query))
idxs = [most_similar]
selected = np.array([embedding_list[most_similar]])
while len(idxs) < min(k, len(embedding_list)):
best_score = -np.inf
idx_to_add = -1
similarity_to_selected = cosine_similarity(embedding_list, selected)
for i, query_score in enumerate(similarity_to_query):
if i in idxs:
continue
redundant_score = max(similarity_to_selected[i])
equation_score = (
lambda_mult * query_score - (1 - lambda_mult) * redundant_score
)
if equation_score > best_score:
best_score = equation_score
idx_to_add = i
idxs.append(idx_to_add)
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
return idxs

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -1,92 +0,0 @@
[tool.poetry]
name = "langchain-astradb"
version = "0.0.1"
description = "An integration package connecting Astra DB and LangChain"
authors = []
readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.5"
astrapy = "^0.7.5"
numpy = "^1"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
pytest-dotenv = "^0.5.2"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
langchain = { path = "../../langchain", develop = true }
langchain-community = { path = "../../community", develop = true }
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
[tool.ruff]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = ["tests/*"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than failing the test suite.
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)

@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

@@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

@@ -1,5 +0,0 @@
# astra db
ASTRA_DB_API_ENDPOINT=https://your_astra_db_id-your_region.apps.astra.datastax.com
ASTRA_DB_APPLICATION_TOKEN=AstraCS:your_astra_db_application_token
# ASTRA_DB_KEYSPACE=your_astra_db_namespace
# ASTRA_DB_SKIP_COLLECTION_DELETIONS=true

@@ -1,19 +0,0 @@
# Getting the absolute path of the current file's directory
import os
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
# Loading the .env file if it exists
def _load_env() -> None:
dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
if os.path.exists(dotenv_path):
from dotenv import load_dotenv
load_dotenv(dotenv_path)
_load_env()

@@ -1,198 +0,0 @@
import os
from typing import AsyncIterable, Iterable
import pytest
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_astradb.chat_message_histories import (
AstraDBChatMessageHistory,
)
from langchain_astradb.utils.astradb import SetupMode
def _has_env_vars() -> bool:
return all(
[
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
"ASTRA_DB_API_ENDPOINT" in os.environ,
]
)
@pytest.fixture(scope="function")
def history1() -> Iterable[AstraDBChatMessageHistory]:
history1 = AstraDBChatMessageHistory(
session_id="session-test-1",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield history1
history1.collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture(scope="function")
def history2() -> Iterable[AstraDBChatMessageHistory]:
history2 = AstraDBChatMessageHistory(
session_id="session-test-2",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
yield history2
history2.collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
history1 = AstraDBChatMessageHistory(
session_id="async-session-test-1",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield history1
await history1.async_collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.fixture(scope="function")
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
history2 = AstraDBChatMessageHistory(
session_id="async-session-test-2",
collection_name="langchain_cmh_test",
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
setup_mode=SetupMode.ASYNC,
)
yield history2
await history2.async_collection.astra_db.delete_collection("langchain_cmh_test")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
"""Test the memory with a message store."""
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=history1,
return_messages=True,
)
assert memory.chat_memory.messages == []
# add some messages
memory.chat_memory.add_messages(
[
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
)
messages = memory.chat_memory.messages
expected = [
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
assert messages == expected
# clear the store
memory.chat_memory.clear()
assert memory.chat_memory.messages == []
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_with_message_store_async(
async_history1: AstraDBChatMessageHistory,
) -> None:
"""Test the memory with a message store."""
memory = ConversationBufferMemory(
memory_key="baz",
chat_memory=async_history1,
return_messages=True,
)
assert await memory.chat_memory.aget_messages() == []
# add some messages
await memory.chat_memory.aadd_messages(
[
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
)
messages = await memory.chat_memory.aget_messages()
expected = [
AIMessage(content="This is me, the AI"),
HumanMessage(content="This is me, the human"),
]
assert messages == expected
# clear the store
await memory.chat_memory.aclear()
assert await memory.chat_memory.aget_messages() == []
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_separate_session_ids(
history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory
) -> None:
"""Test that separate session IDs do not share entries."""
memory1 = ConversationBufferMemory(
memory_key="mk1",
chat_memory=history1,
return_messages=True,
)
memory2 = ConversationBufferMemory(
memory_key="mk2",
chat_memory=history2,
return_messages=True,
)
memory1.chat_memory.add_messages([AIMessage(content="Just saying.")])
assert memory2.chat_memory.messages == []
memory2.chat_memory.clear()
assert memory1.chat_memory.messages != []
memory1.chat_memory.clear()
assert memory1.chat_memory.messages == []
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_separate_session_ids_async(
async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
) -> None:
"""Test that separate session IDs do not share entries."""
memory1 = ConversationBufferMemory(
memory_key="mk1",
chat_memory=async_history1,
return_messages=True,
)
memory2 = ConversationBufferMemory(
memory_key="mk2",
chat_memory=async_history2,
return_messages=True,
)
await memory1.chat_memory.aadd_messages([AIMessage(content="Just saying.")])
assert await memory2.chat_memory.aget_messages() == []
await memory2.chat_memory.aclear()
assert await memory1.chat_memory.aget_messages() != []
await memory1.chat_memory.aclear()
assert await memory1.chat_memory.aget_messages() == []

@@ -1,7 +0,0 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

@@ -1,176 +0,0 @@
"""Implement integration tests for AstraDB storage."""
from __future__ import annotations
import os
import pytest
from astrapy.db import AstraDB, AsyncAstraDB
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
from langchain_astradb.utils.astradb import SetupMode
def _has_env_vars() -> bool:
return all(
[
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
"ASTRA_DB_API_ENDPOINT" in os.environ,
]
)
@pytest.fixture
def astra_db() -> AstraDB:
from astrapy.db import AstraDB
return AstraDB(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
@pytest.fixture
def async_astra_db() -> AsyncAstraDB:
from astrapy.db import AsyncAstraDB
return AsyncAstraDB(
token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
)
def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore:
store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
return store
def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore:
store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", b"value1"), ("key2", b"value2")])
return store
async def init_async_store(
async_astra_db: AsyncAstraDB, collection_name: str
) -> AstraDBStore:
store = AstraDBStore(
collection_name=collection_name,
async_astra_db_client=async_astra_db,
setup_mode=SetupMode.ASYNC,
)
await store.amset([("key1", [0.1, 0.2]), ("key2", "value2")])
return store
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBStore:
def test_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBStore mget method."""
collection_name = "lc_test_store_mget"
try:
store = init_store(astra_db, collection_name)
assert store.mget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
finally:
astra_db.delete_collection(collection_name)
async def test_amget(self, async_astra_db: AsyncAstraDB) -> None:
"""Test AstraDBStore amget method."""
collection_name = "lc_test_store_mget"
try:
store = await init_async_store(async_astra_db, collection_name)
assert await store.amget(["key1", "key2"]) == [[0.1, 0.2], "value2"]
finally:
await async_astra_db.delete_collection(collection_name)
def test_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBStore."""
collection_name = "lc_test_store_mset"
try:
store = init_store(astra_db, collection_name)
result = store.collection.find_one({"_id": "key1"})
assert result["data"]["document"]["value"] == [0.1, 0.2]
result = store.collection.find_one({"_id": "key2"})
assert result["data"]["document"]["value"] == "value2"
finally:
astra_db.delete_collection(collection_name)
async def test_amset(self, async_astra_db: AsyncAstraDB) -> None:
"""Test that multiple keys can be set with AstraDBStore."""
collection_name = "lc_test_store_mset"
try:
store = await init_async_store(async_astra_db, collection_name)
result = await store.async_collection.find_one({"_id": "key1"})
assert result["data"]["document"]["value"] == [0.1, 0.2]
result = await store.async_collection.find_one({"_id": "key2"})
assert result["data"]["document"]["value"] == "value2"
finally:
await async_astra_db.delete_collection(collection_name)
def test_mdelete(self, astra_db: AstraDB) -> None:
"""Test that deletion works as expected."""
collection_name = "lc_test_store_mdelete"
try:
store = init_store(astra_db, collection_name)
store.mdelete(["key1", "key2"])
result = store.mget(["key1", "key2"])
assert result == [None, None]
finally:
astra_db.delete_collection(collection_name)
async def test_amdelete(self, async_astra_db: AsyncAstraDB) -> None:
"""Test that deletion works as expected."""
collection_name = "lc_test_store_mdelete"
try:
store = await init_async_store(async_astra_db, collection_name)
await store.amdelete(["key1", "key2"])
result = await store.amget(["key1", "key2"])
assert result == [None, None]
finally:
await async_astra_db.delete_collection(collection_name)
def test_yield_keys(self, astra_db: AstraDB) -> None:
collection_name = "lc_test_store_yield_keys"
try:
store = init_store(astra_db, collection_name)
assert set(store.yield_keys()) == {"key1", "key2"}
assert set(store.yield_keys(prefix="key")) == {"key1", "key2"}
assert set(store.yield_keys(prefix="lang")) == set()
finally:
astra_db.delete_collection(collection_name)
async def test_ayield_keys(self, async_astra_db: AsyncAstraDB) -> None:
collection_name = "lc_test_store_yield_keys"
try:
store = await init_async_store(async_astra_db, collection_name)
assert {key async for key in store.ayield_keys()} == {"key1", "key2"}
assert {key async for key in store.ayield_keys(prefix="key")} == {
"key1",
"key2",
}
assert {key async for key in store.ayield_keys(prefix="lang")} == set()
finally:
await async_astra_db.delete_collection(collection_name)
def test_bytestore_mget(self, astra_db: AstraDB) -> None:
"""Test AstraDBByteStore mget method."""
collection_name = "lc_test_bytestore_mget"
try:
store = init_bytestore(astra_db, collection_name)
assert store.mget(["key1", "key2"]) == [b"value1", b"value2"]
finally:
astra_db.delete_collection(collection_name)
def test_bytestore_mset(self, astra_db: AstraDB) -> None:
"""Test that multiple keys can be set with AstraDBByteStore."""
collection_name = "lc_test_bytestore_mset"
try:
store = init_bytestore(astra_db, collection_name)
result = store.collection.find_one({"_id": "key1"})
assert result["data"]["document"]["value"] == "dmFsdWUx"
result = store.collection.find_one({"_id": "key2"})
assert result["data"]["document"]["value"] == "dmFsdWUy"
finally:
astra_db.delete_collection(collection_name)

@@ -1,868 +0,0 @@
"""
Test of Astra DB vector store class `AstraDBVectorStore`
Required to run this test:
- a recent `astrapy` Python package available
- an Astra DB instance;
- the two environment variables set:
export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
- optionally this as well (otherwise defaults are used):
export ASTRA_DB_KEYSPACE="my_keyspace"
- optionally:
export ASTRA_DB_SKIP_COLLECTION_DELETIONS="1" ("1" = no deletions; default is "0", i.e. deletions are performed)
"""
import json
import math
import os
from typing import Iterable, List, Optional, TypedDict
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_astradb.vectorstores import AstraDBVectorStore
# Faster testing (no actual collection deletions). Off by default (=full tests)
SKIP_COLLECTION_DELETE = (
int(os.environ.get("ASTRA_DB_SKIP_COLLECTION_DELETIONS", "0")) != 0
)
COLLECTION_NAME_DIM2 = "lc_test_d2"
COLLECTION_NAME_DIM2_EUCLIDEAN = "lc_test_d2_eucl"
MATCH_EPSILON = 0.0001
# Ad-hoc embedding classes:
class AstraDBCredentials(TypedDict):
token: str
api_endpoint: str
namespace: Optional[str]
class SomeEmbeddings(Embeddings):
"""
Turn a sentence into an embedding vector in some way.
Not important how; that it is deterministic is all that counts.
"""
def __init__(self, dimension: int) -> None:
self.dimension = dimension
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(txt) for txt in texts]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
unnormed0 = [ord(c) for c in text[: self.dimension]]
unnormed = (unnormed0 + [1] + [0] * (self.dimension - 1 - len(unnormed0)))[
: self.dimension
]
norm = sum(x * x for x in unnormed) ** 0.5
normed = [x / norm for x in unnormed]
return normed
async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)
class ParserEmbeddings(Embeddings):
"""
Parse input texts: if they are JSON for a List[float], fine.
Otherwise, return all zeros and call it a day.
"""
def __init__(self, dimension: int) -> None:
self.dimension = dimension
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(txt) for txt in texts]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
try:
vals = json.loads(text)
assert len(vals) == self.dimension
return vals
except Exception:
print(f'[ParserEmbeddings] Returning a moot vector for "{text}"')
return [0.0] * self.dimension
async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)
def _has_env_vars() -> bool:
return all(
[
"ASTRA_DB_APPLICATION_TOKEN" in os.environ,
"ASTRA_DB_API_ENDPOINT" in os.environ,
]
)
@pytest.fixture(scope="session")
def astradb_credentials() -> Iterable[AstraDBCredentials]:
yield {
"token": os.environ["ASTRA_DB_APPLICATION_TOKEN"],
"api_endpoint": os.environ["ASTRA_DB_API_ENDPOINT"],
"namespace": os.environ.get("ASTRA_DB_KEYSPACE"),
}
@pytest.fixture(scope="function")
def store_someemb(
astradb_credentials: AstraDBCredentials,
) -> Iterable[AstraDBVectorStore]:
emb = SomeEmbeddings(dimension=2)
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
v_store.clear()
yield v_store
if not SKIP_COLLECTION_DELETE:
v_store.delete_collection()
else:
v_store.clear()
@pytest.fixture(scope="function")
def store_parseremb(
astradb_credentials: AstraDBCredentials,
) -> Iterable[AstraDBVectorStore]:
emb = ParserEmbeddings(dimension=2)
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
v_store.clear()
yield v_store
if not SKIP_COLLECTION_DELETE:
v_store.delete_collection()
else:
v_store.clear()
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBVectorStore:
def test_astradb_vectorstore_create_delete(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""Create and delete."""
from astrapy.db import AstraDB as LibAstraDB
emb = SomeEmbeddings(dimension=2)
# creation by passing the connection secrets
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
v_store.add_texts("Sample 1")
if not SKIP_COLLECTION_DELETE:
v_store.delete_collection()
else:
v_store.clear()
# Creation by passing a ready-made astrapy client:
astra_db_client = LibAstraDB(
**astradb_credentials,
)
v_store_2 = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
astra_db_client=astra_db_client,
)
v_store_2.add_texts("Sample 2")
if not SKIP_COLLECTION_DELETE:
v_store_2.delete_collection()
else:
v_store_2.clear()
async def test_astradb_vectorstore_create_delete_async(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""Create and delete."""
emb = SomeEmbeddings(dimension=2)
# creation by passing the connection secrets
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
await v_store.adelete_collection()
# Creation by passing a ready-made astrapy client:
from astrapy.db import AsyncAstraDB
astra_db_client = AsyncAstraDB(
**astradb_credentials,
)
v_store_2 = AstraDBVectorStore(
embedding=emb,
collection_name="lc_test_2_async",
async_astra_db_client=astra_db_client,
)
if not SKIP_COLLECTION_DELETE:
await v_store_2.adelete_collection()
else:
await v_store_2.aclear()
@pytest.mark.skipif(
SKIP_COLLECTION_DELETE,
reason="Collection-deletion tests are suppressed",
)
def test_astradb_vectorstore_pre_delete_collection(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""Use of the pre_delete_collection flag."""
emb = SomeEmbeddings(dimension=2)
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
v_store.clear()
try:
v_store.add_texts(
texts=["aa"],
metadatas=[
{"k": "a", "ord": 0},
],
ids=["a"],
)
res1 = v_store.similarity_search("aa", k=5)
assert len(res1) == 1
v_store = AstraDBVectorStore(
embedding=emb,
pre_delete_collection=True,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
res1 = v_store.similarity_search("aa", k=5)
assert len(res1) == 0
finally:
v_store.delete_collection()
@pytest.mark.skipif(
SKIP_COLLECTION_DELETE,
reason="Collection-deletion tests are suppressed",
)
async def test_astradb_vectorstore_pre_delete_collection_async(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""Use of the pre_delete_collection flag."""
emb = SomeEmbeddings(dimension=2)
# creation by passing the connection secrets
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
try:
await v_store.aadd_texts(
texts=["aa"],
metadatas=[
{"k": "a", "ord": 0},
],
ids=["a"],
)
res1 = await v_store.asimilarity_search("aa", k=5)
assert len(res1) == 1
v_store = AstraDBVectorStore(
embedding=emb,
pre_delete_collection=True,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
res1 = await v_store.asimilarity_search("aa", k=5)
assert len(res1) == 0
finally:
await v_store.adelete_collection()
def test_astradb_vectorstore_from_x(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""from_texts and from_documents methods."""
emb = SomeEmbeddings(dimension=2)
# prepare empty collection
AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
).clear()
# from_texts
v_store = AstraDBVectorStore.from_texts(
texts=["Hi", "Ho"],
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
try:
assert v_store.similarity_search("Ho", k=1)[0].page_content == "Ho"
finally:
if not SKIP_COLLECTION_DELETE:
v_store.delete_collection()
else:
v_store.clear()
# from_documents
v_store_2 = AstraDBVectorStore.from_documents(
[
Document(page_content="Hee"),
Document(page_content="Hoi"),
],
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
try:
assert v_store_2.similarity_search("Hoi", k=1)[0].page_content == "Hoi"
finally:
if not SKIP_COLLECTION_DELETE:
v_store_2.delete_collection()
else:
v_store_2.clear()
async def test_astradb_vectorstore_from_x_async(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""from_texts and from_documents methods."""
emb = SomeEmbeddings(dimension=2)
# prepare empty collection
await AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
).aclear()
# from_texts
v_store = await AstraDBVectorStore.afrom_texts(
texts=["Hi", "Ho"],
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
try:
assert (await v_store.asimilarity_search("Ho", k=1))[0].page_content == "Ho"
finally:
if not SKIP_COLLECTION_DELETE:
await v_store.adelete_collection()
else:
await v_store.aclear()
# from_documents
v_store_2 = await AstraDBVectorStore.afrom_documents(
[
Document(page_content="Hee"),
Document(page_content="Hoi"),
],
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
)
try:
assert (await v_store_2.asimilarity_search("Hoi", k=1))[
0
].page_content == "Hoi"
finally:
if not SKIP_COLLECTION_DELETE:
await v_store_2.adelete_collection()
else:
await v_store_2.aclear()
def test_astradb_vectorstore_crud(self, store_someemb: AstraDBVectorStore) -> None:
"""Basic add/delete/update behaviour."""
res0 = store_someemb.similarity_search("Abc", k=2)
assert res0 == []
# write and check again
store_someemb.add_texts(
texts=["aa", "bb", "cc"],
metadatas=[
{"k": "a", "ord": 0},
{"k": "b", "ord": 1},
{"k": "c", "ord": 2},
],
ids=["a", "b", "c"],
)
res1 = store_someemb.similarity_search("Abc", k=5)
assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
# partial overwrite and count total entries
store_someemb.add_texts(
texts=["cc", "dd"],
metadatas=[
{"k": "c_new", "ord": 102},
{"k": "d_new", "ord": 103},
],
ids=["c", "d"],
)
res2 = store_someemb.similarity_search("Abc", k=10)
assert len(res2) == 4
# pick one that was just updated and check its metadata
res3 = store_someemb.similarity_search_with_score_id(
query="cc", k=1, filter={"k": "c_new"}
)
print(str(res3))
doc3, score3, id3 = res3[0]
assert doc3.page_content == "cc"
assert doc3.metadata == {"k": "c_new", "ord": 102}
assert score3 > 0.999 # leaving some leeway for approximations...
assert id3 == "c"
# delete and count again
del1_res = store_someemb.delete(["b"])
assert del1_res is True
del2_res = store_someemb.delete(["a", "c", "Z!"])
assert del2_res is True # a non-existing ID was supplied
assert len(store_someemb.similarity_search("xy", k=10)) == 1
# clear store
store_someemb.clear()
assert store_someemb.similarity_search("Abc", k=2) == []
# add_documents with "ids" arg passthrough
store_someemb.add_documents(
[
Document(page_content="vv", metadata={"k": "v", "ord": 204}),
Document(page_content="ww", metadata={"k": "w", "ord": 205}),
],
ids=["v", "w"],
)
assert len(store_someemb.similarity_search("xy", k=10)) == 2
res4 = store_someemb.similarity_search("ww", k=1, filter={"k": "w"})
assert res4[0].metadata["ord"] == 205
async def test_astradb_vectorstore_crud_async(
self, store_someemb: AstraDBVectorStore
) -> None:
"""Basic add/delete/update behaviour."""
res0 = await store_someemb.asimilarity_search("Abc", k=2)
assert res0 == []
# write and check again
await store_someemb.aadd_texts(
texts=["aa", "bb", "cc"],
metadatas=[
{"k": "a", "ord": 0},
{"k": "b", "ord": 1},
{"k": "c", "ord": 2},
],
ids=["a", "b", "c"],
)
res1 = await store_someemb.asimilarity_search("Abc", k=5)
assert {doc.page_content for doc in res1} == {"aa", "bb", "cc"}
# partial overwrite and count total entries
await store_someemb.aadd_texts(
texts=["cc", "dd"],
metadatas=[
{"k": "c_new", "ord": 102},
{"k": "d_new", "ord": 103},
],
ids=["c", "d"],
)
res2 = await store_someemb.asimilarity_search("Abc", k=10)
assert len(res2) == 4
# pick one that was just updated and check its metadata
res3 = await store_someemb.asimilarity_search_with_score_id(
query="cc", k=1, filter={"k": "c_new"}
)
print(str(res3))
doc3, score3, id3 = res3[0]
assert doc3.page_content == "cc"
assert doc3.metadata == {"k": "c_new", "ord": 102}
assert score3 > 0.999 # leaving some leeway for approximations...
assert id3 == "c"
# delete and count again
del1_res = await store_someemb.adelete(["b"])
assert del1_res is True
del2_res = await store_someemb.adelete(["a", "c", "Z!"])
assert del2_res is False # a non-existing ID was supplied
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 1
# clear store
await store_someemb.aclear()
assert await store_someemb.asimilarity_search("Abc", k=2) == []
# add_documents with "ids" arg passthrough
await store_someemb.aadd_documents(
[
Document(page_content="vv", metadata={"k": "v", "ord": 204}),
Document(page_content="ww", metadata={"k": "w", "ord": 205}),
],
ids=["v", "w"],
)
assert len(await store_someemb.asimilarity_search("xy", k=10)) == 2
res4 = await store_someemb.asimilarity_search("ww", k=1, filter={"k": "w"})
assert res4[0].metadata["ord"] == 205
def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDBVectorStore) -> None:
"""
MMR testing. We work on the unit circle with angle multiples
of 2*pi/20 and prepare a store with known vectors for a controlled
MMR outcome.
"""
def _v_from_i(i: int, N: int) -> str:
angle = 2 * math.pi * i / N
vector = [math.cos(angle), math.sin(angle)]
return json.dumps(vector)
i_vals = [0, 4, 5, 13]
N_val = 20
store_parseremb.add_texts(
[_v_from_i(i, N_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals]
)
res1 = store_parseremb.max_marginal_relevance_search(
_v_from_i(3, N_val),
k=2,
fetch_k=3,
)
res_i_vals = {doc.metadata["i"] for doc in res1}
assert res_i_vals == {0, 4}
async def test_astradb_vectorstore_mmr_async(
self, store_parseremb: AstraDBVectorStore
) -> None:
"""
MMR testing. We work on the unit circle with angle multiples
of 2*pi/20 and prepare a store with known vectors for a controlled
MMR outcome.
"""
def _v_from_i(i: int, N: int) -> str:
angle = 2 * math.pi * i / N
vector = [math.cos(angle), math.sin(angle)]
return json.dumps(vector)
i_vals = [0, 4, 5, 13]
N_val = 20
await store_parseremb.aadd_texts(
[_v_from_i(i, N_val) for i in i_vals],
metadatas=[{"i": i} for i in i_vals],
)
res1 = await store_parseremb.amax_marginal_relevance_search(
_v_from_i(3, N_val),
k=2,
fetch_k=3,
)
res_i_vals = {doc.metadata["i"] for doc in res1}
assert res_i_vals == {0, 4}
def test_astradb_vectorstore_metadata(
self, store_someemb: AstraDBVectorStore
) -> None:
"""Metadata filtering."""
store_someemb.add_documents(
[
Document(
page_content="q",
metadata={"ord": ord("q"), "group": "consonant"},
),
Document(
page_content="w",
metadata={"ord": ord("w"), "group": "consonant"},
),
Document(
page_content="r",
metadata={"ord": ord("r"), "group": "consonant"},
),
Document(
page_content="e",
metadata={"ord": ord("e"), "group": "vowel"},
),
Document(
page_content="i",
metadata={"ord": ord("i"), "group": "vowel"},
),
Document(
page_content="o",
metadata={"ord": ord("o"), "group": "vowel"},
),
]
)
# no filters
res0 = store_someemb.similarity_search("x", k=10)
assert {doc.page_content for doc in res0} == set("qwreio")
# single filter
res1 = store_someemb.similarity_search(
"x",
k=10,
filter={"group": "vowel"},
)
assert {doc.page_content for doc in res1} == set("eio")
# multiple filters
res2 = store_someemb.similarity_search(
"x",
k=10,
filter={"group": "consonant", "ord": ord("q")},
)
assert {doc.page_content for doc in res2} == set("q")
# excessive filters
res3 = store_someemb.similarity_search(
"x",
k=10,
filter={"group": "consonant", "ord": ord("q"), "case": "upper"},
)
assert res3 == []
# filter with logical operator
res4 = store_someemb.similarity_search(
"x",
k=10,
filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]},
)
assert {doc.page_content for doc in res4} == {"q", "r"}
def test_astradb_vectorstore_similarity_scale(
self, store_parseremb: AstraDBVectorStore
) -> None:
"""Scale of the similarity scores."""
store_parseremb.add_texts(
texts=[
json.dumps([1, 1]),
json.dumps([-1, -1]),
],
ids=["near", "far"],
)
res1 = store_parseremb.similarity_search_with_score(
json.dumps([0.5, 0.5]),
k=2,
)
scores = [sco for _, sco in res1]
sco_near, sco_far = scores
assert abs(1 - sco_near) < MATCH_EPSILON and abs(sco_far) < MATCH_EPSILON
async def test_astradb_vectorstore_similarity_scale_async(
self, store_parseremb: AstraDBVectorStore
) -> None:
"""Scale of the similarity scores."""
await store_parseremb.aadd_texts(
texts=[
json.dumps([1, 1]),
json.dumps([-1, -1]),
],
ids=["near", "far"],
)
res1 = await store_parseremb.asimilarity_search_with_score(
json.dumps([0.5, 0.5]),
k=2,
)
scores = [sco for _, sco in res1]
sco_near, sco_far = scores
assert abs(1 - sco_near) < MATCH_EPSILON and abs(sco_far) < MATCH_EPSILON
def test_astradb_vectorstore_massive_delete(
self, store_someemb: AstraDBVectorStore
) -> None:
"""Larger-scale bulk deletes."""
M = 50
texts = [str(i + 1 / 7.0) for i in range(2 * M)]
ids0 = ["doc_%i" % i for i in range(M)]
ids1 = ["doc_%i" % (i + M) for i in range(M)]
ids = ids0 + ids1
store_someemb.add_texts(texts=texts, ids=ids)
# deleting a bunch of these
del_res0 = store_someemb.delete(ids0)
assert del_res0 is True
# deleting the rest plus a fake one
del_res1 = store_someemb.delete(ids1 + ["ghost!"])
assert del_res1 is True # ensure no error
# nothing left
assert store_someemb.similarity_search("x", k=2 * M) == []
@pytest.mark.skipif(
SKIP_COLLECTION_DELETE,
reason="Collection-deletion tests are suppressed",
)
def test_astradb_vectorstore_delete_collection(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""behaviour of 'delete_collection'."""
collection_name = COLLECTION_NAME_DIM2
emb = SomeEmbeddings(dimension=2)
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=collection_name,
**astradb_credentials,
)
v_store.add_texts(["huh"])
assert len(v_store.similarity_search("hah", k=10)) == 1
# another instance pointing to the same collection on DB
v_store_kenny = AstraDBVectorStore(
embedding=emb,
collection_name=collection_name,
**astradb_credentials,
)
v_store_kenny.delete_collection()
# dropped on DB, but 'v_store' should have no clue:
with pytest.raises(ValueError):
_ = v_store.similarity_search("hah", k=10)
def test_astradb_vectorstore_custom_params(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""Custom batch size and concurrency params."""
emb = SomeEmbeddings(dimension=2)
# prepare empty collection
AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
).clear()
v_store = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
batch_size=17,
bulk_insert_batch_concurrency=13,
bulk_insert_overwrite_concurrency=7,
bulk_delete_concurrency=19,
)
try:
# add_texts
N = 50
texts = [str(i + 1 / 7.0) for i in range(N)]
ids = ["doc_%i" % i for i in range(N)]
v_store.add_texts(texts=texts, ids=ids)
v_store.add_texts(
texts=texts,
ids=ids,
batch_size=19,
batch_concurrency=7,
overwrite_concurrency=13,
)
#
_ = v_store.delete(ids[: N // 2])
_ = v_store.delete(ids[N // 2 :], concurrency=23)
#
finally:
if not SKIP_COLLECTION_DELETE:
v_store.delete_collection()
else:
v_store.clear()
async def test_astradb_vectorstore_custom_params_async(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""Custom batch size and concurrency params."""
emb = SomeEmbeddings(dimension=2)
v_store = AstraDBVectorStore(
embedding=emb,
collection_name="lc_test_c_async",
batch_size=17,
bulk_insert_batch_concurrency=13,
bulk_insert_overwrite_concurrency=7,
bulk_delete_concurrency=19,
**astradb_credentials,
)
try:
# add_texts
N = 50
texts = [str(i + 1 / 7.0) for i in range(N)]
ids = ["doc_%i" % i for i in range(N)]
await v_store.aadd_texts(texts=texts, ids=ids)
await v_store.aadd_texts(
texts=texts,
ids=ids,
batch_size=19,
batch_concurrency=7,
overwrite_concurrency=13,
)
#
await v_store.adelete(ids[: N // 2])
await v_store.adelete(ids[N // 2 :], concurrency=23)
#
finally:
if not SKIP_COLLECTION_DELETE:
await v_store.adelete_collection()
else:
await v_store.aclear()
def test_astradb_vectorstore_metrics(
self, astradb_credentials: AstraDBCredentials
) -> None:
"""
Different choices of similarity metric.
Both stores (with "cosine" and "euclidean" metrics) contain these two:
- a vector slightly rotated w.r.t query vector
- a vector which is a long multiple of query vector
so, which one is "the closest one" depends on the metric.
"""
emb = ParserEmbeddings(dimension=2)
isq2 = 0.5**0.5
isa = 0.7
isb = (1.0 - isa * isa) ** 0.5
texts = [
json.dumps([isa, isb]),
json.dumps([10 * isq2, 10 * isq2]),
]
ids = [
"rotated",
"scaled",
]
query_text = json.dumps([isq2, isq2])
# prepare empty collections
AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
**astradb_credentials,
).clear()
AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
metric="euclidean",
**astradb_credentials,
).clear()
# creation, population, query - cosine
vstore_cos = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2,
metric="cosine",
**astradb_credentials,
)
try:
vstore_cos.add_texts(
texts=texts,
ids=ids,
)
_, _, id_from_cos = vstore_cos.similarity_search_with_score_id(
query_text,
k=1,
)[0]
assert id_from_cos == "scaled"
finally:
if not SKIP_COLLECTION_DELETE:
vstore_cos.delete_collection()
else:
vstore_cos.clear()
# creation, population, query - euclidean
vstore_euc = AstraDBVectorStore(
embedding=emb,
collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
metric="euclidean",
**astradb_credentials,
)
try:
vstore_euc.add_texts(
texts=texts,
ids=ids,
)
_, _, id_from_euc = vstore_euc.similarity_search_with_score_id(
query_text,
k=1,
)[0]
assert id_from_euc == "rotated"
finally:
if not SKIP_COLLECTION_DELETE:
vstore_euc.delete_collection()
else:
vstore_euc.clear()

@@ -1,12 +0,0 @@
from langchain_astradb import __all__
EXPECTED_ALL = [
"AstraDBByteStore",
"AstraDBStore",
"AstraDBChatMessageHistory",
"AstraDBVectorStore",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

@@ -1,45 +0,0 @@
from typing import List
from unittest.mock import Mock
from langchain_core.embeddings import Embeddings
from langchain_astradb.vectorstores import AstraDBVectorStore
class SomeEmbeddings(Embeddings):
"""
Turn a sentence into an embedding vector in some way.
Not important how; that it is deterministic is all that counts.
"""
def __init__(self, dimension: int) -> None:
self.dimension = dimension
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(txt) for txt in texts]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
unnormed0 = [ord(c) for c in text[: self.dimension]]
unnormed = (unnormed0 + [1] + [0] * (self.dimension - 1 - len(unnormed0)))[
: self.dimension
]
norm = sum(x * x for x in unnormed) ** 0.5
normed = [x / norm for x in unnormed]
return normed
async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)
def test_initialization() -> None:
"""Test integration vectorstore initialization."""
mock_astra_db = Mock()
embedding = SomeEmbeddings(dimension=2)
AstraDBVectorStore(
embedding=embedding,
collection_name="mock_coll_name",
astra_db_client=mock_astra_db,
)