mirror of https://github.com/hwchase17/langchain
astradb: move to langchain-datastax repo (#18354)
parent
b641be2edf
commit
6afb135baa
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -1,66 +0,0 @@
|
||||
SHELL := /bin/bash
|
||||
# Every target in this Makefile is a command, not a file to build. Declare all
# of them phony so a same-named file or directory can never make `make` skip
# the recipe.
.PHONY: all format format_diff lint lint_diff lint_package lint_tests test tests integration_test integration_tests spell_check spell_fix check_imports help
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
INTEGRATION_TEST_FILE ?= tests/integration_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
integration_test:
|
||||
poetry run pytest $(INTEGRATION_TEST_FILE)
|
||||
|
||||
integration_tests:
|
||||
poetry run pytest $(INTEGRATION_TEST_FILE)
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/astradb --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_astradb
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
poetry run ruff .
|
||||
poetry run ruff format $(PYTHON_FILES) --diff
|
||||
poetry run ruff --select I $(PYTHON_FILES)
|
||||
mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_astradb -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
@ -1,68 +1,3 @@
|
||||
# langchain-astradb
|
||||
This package has moved!
|
||||
|
||||
This package contains the LangChain integrations for using DataStax Astra DB.
|
||||
|
||||
> DataStax [Astra DB](https://docs.datastax.com/en/astra/home/astra.html) is a serverless vector-capable database built on Apache Cassandra® and made conveniently available
|
||||
> through an easy-to-use JSON API.
|
||||
|
||||
_**Note.** For a short transitional period, only some of the Astra DB integration classes are contained in this package (the remaining ones being still in `langchain-community`). In a short while, and surely by version 0.2 of LangChain, all of the Astra DB support will be removed from `langchain-community` and included in this package._
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
Installation of this partner package:
|
||||
|
||||
```bash
|
||||
pip install langchain-astradb
|
||||
```
|
||||
|
||||
## Integrations overview
|
||||
|
||||
### Vector Store
|
||||
|
||||
```python
|
||||
from langchain_astradb import AstraDBVectorStore
|
||||
|
||||
my_store = AstraDBVectorStore(
|
||||
embedding=my_embeddings,
|
||||
collection_name="my_store",
|
||||
api_endpoint="https://...",
|
||||
token="AstraCS:...",
|
||||
)
|
||||
```
|
||||
|
||||
### Chat message history
|
||||
|
||||
```python
|
||||
from langchain_astradb import AstraDBChatMessageHistory
|
||||
message_history = AstraDBChatMessageHistory(
|
||||
session_id="test-session",
|
||||
api_endpoint="...",
|
||||
token="...",
|
||||
)
|
||||
```
|
||||
|
||||
### Store
|
||||
|
||||
```python
|
||||
from langchain_astradb import AstraDBStore
|
||||
store = AstraDBStore(
|
||||
collection_name="my_kv_store",
|
||||
api_endpoint="...",
|
||||
token="..."
|
||||
)
|
||||
```
|
||||
|
||||
### Byte Store
|
||||
|
||||
```python
|
||||
from langchain_astradb import AstraDBByteStore
|
||||
store = AstraDBByteStore(
|
||||
collection_name="my_kv_store",
|
||||
api_endpoint="...",
|
||||
token="..."
|
||||
)
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
See the [LangChain docs page](https://python.langchain.com/docs/integrations/providers/astradb) for a more detailed listing.
|
||||
https://github.com/langchain-ai/langchain-datastax/tree/main/libs/astradb
|
@ -1,10 +0,0 @@
|
||||
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
|
||||
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
|
||||
from langchain_astradb.vectorstores import AstraDBVectorStore
|
||||
|
||||
__all__ = [
|
||||
"AstraDBByteStore",
|
||||
"AstraDBStore",
|
||||
"AstraDBChatMessageHistory",
|
||||
"AstraDBVectorStore",
|
||||
]
|
@ -1,148 +0,0 @@
|
||||
"""Astra DB - based chat message history, based on astrapy."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import (
|
||||
BaseMessage,
|
||||
message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
from langchain_astradb.utils.astradb import (
|
||||
SetupMode,
|
||||
_AstraDBCollectionEnvironment,
|
||||
)
|
||||
|
||||
DEFAULT_COLLECTION_NAME = "langchain_message_store"
|
||||
|
||||
|
||||
class AstraDBChatMessageHistory(BaseChatMessageHistory):
    """Chat message history persisted in an Astra DB collection.

    Each message is written as one document holding the session id, the
    insertion time ("timestamp") and the JSON-serialized message
    ("body_blob"). Reads return the session's messages sorted by timestamp.
    """

    def __init__(
        self,
        *,
        session_id: str,
        collection_name: str = DEFAULT_COLLECTION_NAME,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
    ) -> None:
        """Chat message history that stores history in Astra DB.

        Args:
            session_id: arbitrary key that is used to store the messages
                of a single chat session.
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as "https://<DB-ID>-us-east1.apps.astra.datastax.com".
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            setup_mode: mode used to create the Astra DB collection
                (SYNC, ASYNC or OFF).
            pre_delete_collection: whether to delete the collection
                before creating it.
        """
        environment = _AstraDBCollectionEnvironment(
            collection_name=collection_name,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
            setup_mode=setup_mode,
            pre_delete_collection=pre_delete_collection,
        )
        self.astra_env = environment
        self.collection = environment.collection
        self.async_collection = environment.async_collection

        self.session_id = session_id
        self.collection_name = collection_name

    def _find_kwargs(self) -> dict:
        """Build the query selecting this session's timestamps and blobs."""
        return {
            "filter": {"session_id": self.session_id},
            "projection": {"timestamp": 1, "body_blob": 1},
        }

    @staticmethod
    def _docs_to_messages(docs: List[dict]) -> List[BaseMessage]:
        """Order stored documents by timestamp and decode them into messages."""
        ordered = sorted(docs, key=lambda _doc: _doc["timestamp"])
        blobs = [entry["body_blob"] for entry in ordered]
        return messages_from_dict([json.loads(blob) for blob in blobs])

    def _messages_to_docs(self, messages: Sequence[BaseMessage]) -> List[dict]:
        """Wrap each message in a document stamped with the current time."""
        return [
            {
                "timestamp": time.time(),
                "session_id": self.session_id,
                "body_blob": json.dumps(message_to_dict(single_message)),
            }
            for single_message in messages
        ]

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve all session messages from DB"""
        self.astra_env.ensure_db_setup()
        stored = list(self.collection.paginated_find(**self._find_kwargs()))
        return self._docs_to_messages(stored)

    @messages.setter
    def messages(self, messages: List[BaseMessage]) -> None:
        # Direct assignment is not supported; messages must be appended.
        raise NotImplementedError("Use add_messages instead")

    async def aget_messages(self) -> List[BaseMessage]:
        """Async counterpart of the `messages` property."""
        await self.astra_env.aensure_db_setup()
        cursor = self.async_collection.paginated_find(**self._find_kwargs())
        stored = [doc async for doc in cursor]
        return self._docs_to_messages(stored)

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Append the given messages to this session's history."""
        self.astra_env.ensure_db_setup()
        self.collection.chunked_insert_many(self._messages_to_docs(messages))

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Async counterpart of `add_messages`."""
        await self.astra_env.aensure_db_setup()
        await self.async_collection.chunked_insert_many(
            self._messages_to_docs(messages)
        )

    def clear(self) -> None:
        """Delete every stored message belonging to this session."""
        self.astra_env.ensure_db_setup()
        self.collection.delete_many(filter={"session_id": self.session_id})

    async def aclear(self) -> None:
        """Async counterpart of `clear`."""
        await self.astra_env.aensure_db_setup()
        await self.async_collection.delete_many(filter={"session_id": self.session_id})
|
@ -1,217 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
|
||||
from langchain_astradb.utils.astradb import (
|
||||
SetupMode,
|
||||
_AstraDBCollectionEnvironment,
|
||||
)
|
||||
|
||||
V = TypeVar("V")
|
||||
|
||||
|
||||
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
    """Base class for the DataStax AstraDB data store.

    Documents are keyed by "_id" and carry the payload under "value".
    Subclasses define how payloads are converted through `encode_value` and
    `decode_value`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward every argument to `_AstraDBCollectionEnvironment`."""
        environment = _AstraDBCollectionEnvironment(*args, **kwargs)
        self.astra_env = environment
        self.collection = environment.collection
        self.async_collection = environment.async_collection

    @abstractmethod
    def decode_value(self, value: Any) -> Optional[V]:
        """Decodes value from Astra DB"""

    @abstractmethod
    def encode_value(self, value: Optional[V]) -> Any:
        """Encodes value for Astra DB"""

    def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
        """Return the decoded value for each key, in the order requested.

        Keys with no matching document are decoded from ``None``.
        """
        self.astra_env.ensure_db_setup()
        matching = self.collection.paginated_find(filter={"_id": {"$in": list(keys)}})
        stored = {doc["_id"]: doc.get("value") for doc in matching}
        return [self.decode_value(stored.get(key)) for key in keys]

    async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
        """Async counterpart of `mget`."""
        await self.astra_env.aensure_db_setup()
        matching = self.async_collection.paginated_find(
            filter={"_id": {"$in": list(keys)}}
        )
        stored = {doc["_id"]: doc.get("value") async for doc in matching}
        return [self.decode_value(stored.get(key)) for key in keys]

    def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
        """Upsert one document per (key, value) pair, one request at a time."""
        self.astra_env.ensure_db_setup()
        for key, raw_value in key_value_pairs:
            document = {"_id": key, "value": self.encode_value(raw_value)}
            self.collection.upsert_one(document)

    async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
        """Async counterpart of `mset`."""
        await self.astra_env.aensure_db_setup()
        for key, raw_value in key_value_pairs:
            document = {"_id": key, "value": self.encode_value(raw_value)}
            await self.async_collection.upsert_one(document)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents whose "_id" is one of the given keys."""
        self.astra_env.ensure_db_setup()
        self.collection.delete_many(filter={"_id": {"$in": list(keys)}})

    async def amdelete(self, keys: Sequence[str]) -> None:
        """Async counterpart of `mdelete`."""
        await self.astra_env.aensure_db_setup()
        await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Iterate over stored keys, optionally keeping only a given prefix.

        The prefix test is applied here on every returned document; no filter
        is passed to the query.
        """
        self.astra_env.ensure_db_setup()
        for doc in self.collection.paginated_find():
            key = doc["_id"]
            if prefix and not key.startswith(prefix):
                continue
            yield key

    async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
        """Async counterpart of `yield_keys`."""
        await self.astra_env.aensure_db_setup()
        async for doc in self.async_collection.paginated_find():
            key = doc["_id"]
            if prefix and not key.startswith(prefix):
                continue
            yield key
|
||||
|
||||
|
||||
class AstraDBStore(AstraDBBaseStore[Any]):
    """Key-value store keeping values in Astra DB exactly as they are given."""

    def __init__(
        self,
        collection_name: str,
        *,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        namespace: Optional[str] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        pre_delete_collection: bool = False,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        """BaseStore implementation using DataStax AstraDB as the underlying store.

        The value type can be any type serializable by json.dumps.
        Can be used to store embeddings with the CacheBackedEmbeddings.

        Documents in the AstraDB collection will have the format

        .. code-block:: json

            {
              "_id": "<key>",
              "value": <value>
            }

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
                OFF).
            pre_delete_collection: whether to delete the collection
                before creating it. If False and the collection already exists,
                the collection will be used as is.
        """
        # Everything is passed by keyword, so the order below is irrelevant.
        super().__init__(
            collection_name=collection_name,
            namespace=namespace,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            pre_delete_collection=pre_delete_collection,
            setup_mode=setup_mode,
        )

    def decode_value(self, value: Any) -> Any:
        """Return the stored value unchanged."""
        return value

    def encode_value(self, value: Any) -> Any:
        """Return the value unchanged for storage."""
        return value
|
||||
|
||||
|
||||
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
    """Key-value store keeping ``bytes`` values in Astra DB as base64 text."""

    def __init__(
        self,
        *,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        namespace: Optional[str] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        pre_delete_collection: bool = False,
        setup_mode: SetupMode = SetupMode.SYNC,
    ) -> None:
        """ByteStore implementation using DataStax AstraDB as the underlying store.

        The bytes values are converted to base64 encoded strings.
        Documents in the AstraDB collection will have the format

        .. code-block:: json

            {
              "_id": "<key>",
              "value": "<base64 string value>"
            }

        Args:
            collection_name: name of the Astra DB collection to create/use.
            token: API token for Astra DB usage.
            api_endpoint: full URL to the API endpoint,
                such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
            astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AstraDB' instance.
            async_astra_db_client: *alternative to token+api_endpoint*,
                you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
            namespace: namespace (aka keyspace) where the
                collection is created. Defaults to the database's "default namespace".
            setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
                OFF).
            pre_delete_collection: whether to delete the collection
                before creating it. If False and the collection already exists,
                the collection will be used as is.
        """
        # Everything is passed by keyword, so the order below is irrelevant.
        super().__init__(
            collection_name=collection_name,
            namespace=namespace,
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            pre_delete_collection=pre_delete_collection,
            setup_mode=setup_mode,
        )

    def decode_value(self, value: Any) -> Optional[bytes]:
        """Decode a stored base64 string into bytes; ``None`` stays ``None``."""
        return None if value is None else base64.b64decode(value)

    def encode_value(self, value: Optional[bytes]) -> Any:
        """Encode bytes as an ASCII base64 string; ``None`` stays ``None``."""
        return None if value is None else base64.b64encode(value).decode("ascii")
|
@ -1,152 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from asyncio import InvalidStateError, Task
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Optional, Union
|
||||
|
||||
import langchain_core
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
|
||||
|
||||
class SetupMode(Enum):
    """How `_AstraDBCollectionEnvironment` provisions its collection."""

    # Delete (if requested) and create the collection with blocking calls
    # inside the constructor.
    SYNC = 1
    # Schedule deletion/creation as an asyncio task started by the constructor.
    ASYNC = 2
    # Neither constructor branch matches this value: no setup is performed.
    OFF = 3
|
||||
|
||||
|
||||
class _AstraDBEnvironment:
    """Holds a sync and an async Astra DB client for one database.

    The clients are either built from `token` + `api_endpoint` or taken from
    caller-supplied instances; a missing one is derived from the other.
    """

    def __init__(
        self,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
    ) -> None:
        """Resolve the sync/async client pair.

        Args:
            token: API token, used together with `api_endpoint`.
            api_endpoint: API endpoint URL, used together with `token`.
            astra_db_client: ready-made sync client (alternative to
                token+api_endpoint).
            async_astra_db_client: ready-made async client (alternative to
                token+api_endpoint).
            namespace: namespace passed to clients created from
                token+api_endpoint.

        Raises:
            ValueError: if a client is combined with `token`/`api_endpoint`,
                or if no usable client or credential pair is provided.
        """
        self.token = token
        self.api_endpoint = api_endpoint
        self.namespace = namespace

        # Conflicting-arg checks: clients and credentials are mutually exclusive.
        client_given = astra_db_client is not None or async_astra_db_client is not None
        credential_given = token is not None or api_endpoint is not None
        if client_given and credential_given:
            raise ValueError(
                "You cannot pass 'astra_db_client' or 'async_astra_db_client' to "
                "AstraDBEnvironment if passing 'token' and 'api_endpoint'."
            )

        sync_client = astra_db_client
        async_client = async_astra_db_client
        if token and api_endpoint:
            sync_client = AstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )
            async_client = AsyncAstraDB(
                token=token,
                api_endpoint=api_endpoint,
                namespace=self.namespace,
            )

        # Store copies of whatever was resolved, deriving the missing flavour.
        if sync_client:
            self.astra_db = sync_client.copy()
            self.async_astra_db = (
                async_client.copy() if async_client else self.astra_db.to_async()
            )
        elif async_client:
            self.async_astra_db = async_client.copy()
            self.astra_db = self.async_astra_db.to_sync()
        else:
            raise ValueError(
                "Must provide 'astra_db_client' or 'async_astra_db_client' or "
                "'token' and 'api_endpoint'"
            )

        # Tag both clients with the calling library's name and version.
        caller_version = getattr(langchain_core, "__version__", None)
        for client in (self.astra_db, self.async_astra_db):
            client.set_caller(
                caller_name="langchain",
                caller_version=caller_version,
            )
|
||||
|
||||
|
||||
class _AstraDBCollectionEnvironment(_AstraDBEnvironment):
    """An `_AstraDBEnvironment` bound to one collection, with optional setup.

    Exposes sync and async collection handles and, depending on `setup_mode`,
    creates the collection in the constructor (SYNC) or in a background
    asyncio task (ASYNC).
    """

    def __init__(
        self,
        collection_name: str,
        token: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        astra_db_client: Optional[AstraDB] = None,
        async_astra_db_client: Optional[AsyncAstraDB] = None,
        namespace: Optional[str] = None,
        setup_mode: SetupMode = SetupMode.SYNC,
        pre_delete_collection: bool = False,
        embedding_dimension: Union[int, Awaitable[int], None] = None,
        metric: Optional[str] = None,
    ) -> None:
        """Create the collection handles and run or schedule collection setup.

        Args:
            collection_name: name of the collection to create/use.
            token: see `_AstraDBEnvironment`.
            api_endpoint: see `_AstraDBEnvironment`.
            astra_db_client: see `_AstraDBEnvironment`.
            async_astra_db_client: see `_AstraDBEnvironment`.
            namespace: see `_AstraDBEnvironment`.
            setup_mode: SYNC creates the collection here, ASYNC schedules an
                asyncio task for it, any other value skips setup.
            pre_delete_collection: delete the collection before creating it.
            embedding_dimension: forwarded as `dimension` to collection
                creation; may be an awaitable only in ASYNC mode.
            metric: forwarded as `metric` to collection creation.

        Raises:
            ValueError: in SYNC mode when `embedding_dimension` is awaitable.
        """
        from astrapy.db import AstraDBCollection, AsyncAstraDBCollection

        super().__init__(
            token=token,
            api_endpoint=api_endpoint,
            astra_db_client=astra_db_client,
            async_astra_db_client=async_astra_db_client,
            namespace=namespace,
        )
        self.collection_name = collection_name
        self.collection = AstraDBCollection(
            collection_name=collection_name,
            astra_db=self.astra_db,
        )
        self.async_collection = AsyncAstraDBCollection(
            collection_name=collection_name,
            astra_db=self.async_astra_db,
        )

        self.async_setup_db_task: Optional[Task] = None
        if setup_mode == SetupMode.SYNC:
            if pre_delete_collection:
                self.astra_db.delete_collection(collection_name)
            if inspect.isawaitable(embedding_dimension):
                raise ValueError(
                    "Cannot use an awaitable embedding_dimension with async_setup "
                    "set to False"
                )
            self.astra_db.create_collection(
                collection_name,
                dimension=embedding_dimension,  # type: ignore[arg-type]
                metric=metric,
            )
        elif setup_mode == SetupMode.ASYNC:
            # Bind the client locally so the coroutine does not go through self.
            async_client = self.async_astra_db

            async def _provision_collection() -> None:
                if pre_delete_collection:
                    await async_client.delete_collection(collection_name)
                dimension = (
                    await embedding_dimension
                    if inspect.isawaitable(embedding_dimension)
                    else embedding_dimension
                )
                await async_client.create_collection(
                    collection_name, dimension=dimension, metric=metric
                )

            self.async_setup_db_task = asyncio.create_task(_provision_collection())

    def ensure_db_setup(self) -> None:
        """Check, without waiting, that any async setup task has completed.

        Raises:
            ValueError: if the setup task exists but has not finished yet.
        """
        setup_task = self.async_setup_db_task
        if not setup_task:
            return
        try:
            # result() also re-raises whatever exception the setup task hit.
            setup_task.result()
        except InvalidStateError:
            raise ValueError(
                "Asynchronous setup of the DB not finished. "
                "NB: AstraDB components sync methods shouldn't be called from the "
                "event loop. Consider using their async equivalents."
            )

    async def aensure_db_setup(self) -> None:
        """Await the async setup task, if one was scheduled."""
        setup_task = self.async_setup_db_task
        if setup_task:
            await setup_task
|
@ -1,87 +0,0 @@
|
||||
"""
|
||||
Tools for the Maximal Marginal Relevance (MMR) reranking.
|
||||
Duplicated from langchain_community to avoid cross-dependencies.
|
||||
|
||||
Functions "maximal_marginal_relevance" and "cosine_similarity"
|
||||
are duplicated in this utility respectively from modules:
|
||||
- "libs/community/langchain_community/vectorstores/utils.py"
|
||||
- "libs/community/langchain_community/utils/math.py"
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)

Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices.

    Args:
        X: matrix with one vector per row.
        Y: matrix with one vector per row and the same number of columns as X.

    Returns:
        An array whose entry ``[i, j]`` is the cosine similarity between row
        ``i`` of X and row ``j`` of Y, or an empty array when either input is
        empty. In the NumPy code path, pairs involving a zero-norm row score 0.

    Raises:
        ValueError: if X and Y have a different number of columns.
    """
    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}."
        )

    # Only the optional import is guarded, so an ImportError raised anywhere
    # else is not mistaken for "simsimd is not installed".
    try:
        import simsimd as simd  # type: ignore
    except ImportError:
        logger.info(
            "Unable to import simsimd, defaulting to NumPy implementation. If you want "
            "to use simsimd please install with `pip install simsimd`."
        )
    else:
        X = np.array(X, dtype=np.float32)
        Y = np.array(Y, dtype=np.float32)
        # cdist's result is converted to an ndarray before the subtraction, so
        # that a return type without arithmetic support cannot break `1 - ...`
        # (NOTE(review): confirm return types against the simsimd docs).
        Z = 1 - np.asarray(simd.cdist(X, Y, metric="cosine"))
        # A scalar result is wrapped so callers always receive an array.
        return np.atleast_1d(Z)

    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    # Ignore divide by zero errors run time warnings as those are handled below.
    with np.errstate(divide="ignore", invalid="ignore"):
        similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
|
||||
|
||||
|
||||
def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance.

    Greedily picks up to ``k`` indices into ``embedding_list``: the first is
    the embedding most similar to the query, and each later pick maximises
    ``lambda_mult * similarity_to_query
    - (1 - lambda_mult) * max_similarity_to_already_picked``.

    Args:
        query_embedding: query vector; a 1-D array is promoted to one row.
        embedding_list: candidate embeddings.
        lambda_mult: weight of query relevance versus redundancy.
        k: maximum number of indices to return.

    Returns:
        The selected indices in pick order; empty when ``k`` or the candidate
        list is empty (or ``k`` is negative).
    """
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)

    relevance_scores = cosine_similarity(query_embedding, embedding_list)[0]
    first_pick = int(np.argmax(relevance_scores))
    picked = [first_pick]
    picked_vectors = np.array([embedding_list[first_pick]])

    while len(picked) < min(k, len(embedding_list)):
        top_score = -np.inf
        top_index = -1
        # Similarity of every candidate to each vector picked so far.
        overlap = cosine_similarity(embedding_list, picked_vectors)
        for candidate, relevance in enumerate(relevance_scores):
            if candidate not in picked:
                redundancy = max(overlap[candidate])
                mmr_score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
                if mmr_score > top_score:
                    top_score = mmr_score
                    top_index = candidate
        picked.append(top_index)
        picked_vectors = np.append(
            picked_vectors, [embedding_list[top_index]], axis=0
        )
    return picked
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,92 +0,0 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-astradb"
|
||||
version = "0.0.1"
|
||||
description = "An integration package connecting Astra DB and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.5"
|
||||
astrapy = "^0.7.5"
|
||||
numpy = "^1"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
pytest-dotenv = "^0.5.2"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
langchain = { path = "../../langchain", develop = true }
|
||||
langchain-community = { path = "../../community", develop = true }
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
@ -1,17 +0,0 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
def check_files(files):
    """Try to load each Python source file and report whether any failed.

    For every file that raises while being loaded, the path and the traceback
    are printed. All files are attempted even after a failure.

    Args:
        files: iterable of paths to Python source files.

    Returns:
        True if at least one file failed to load, False otherwise.
    """
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            # Record the failure; a misspelled flag name here previously left
            # the real flag untouched, so the script always exited with 0.
            has_failure = True
            print(file)
            traceback.print_exc()
            print()
    return has_failure


if __name__ == "__main__":
    # Exit non-zero when any file given on the command line fails to import.
    sys.exit(1 if check_files(sys.argv[1:]) else 0)
|
@ -1,27 +0,0 @@
|
||||
#!/bin/bash
#
# Scan the tracked files of a Git repository for direct pydantic imports,
# i.e. lines starting with "import pydantic" or "from pydantic".
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Exactly one argument (the repository path) is required.
if [ "$#" -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Collect every matching line reported by git grep in that repository.
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Nothing found: the check passes.
if [ -z "$result" ]; then
  exit 0
fi

# Otherwise list the offending lines, explain the expected import and fail.
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
|
@ -1,17 +0,0 @@
|
||||
#!/bin/bash

set -eu

# Number of forbidden import patterns that produced at least one match.
errors=0

# Neither langchain nor langchain_experimental may be imported from.
# git grep prints the matching lines and succeeds only when there is a match.
for pattern in '^from langchain\.' '^from langchain_experimental\.'; do
    if git --no-pager grep "$pattern" .; then
        errors=$((errors+1))
    fi
done

# Succeed only when no forbidden import was found.
if [ "$errors" -eq 0 ]; then
    exit 0
fi
exit 1
|
@ -1,5 +0,0 @@
|
||||
# astra db
|
||||
ASTRA_DB_API_ENDPOINT=https://your_astra_db_id-your_region.apps.astra.datastax.com
|
||||
ASTRA_DB_APPLICATION_TOKEN=AstraCS:your_astra_db_application_token
|
||||
# ASTRA_DB_KEYSPACE=your_astra_db_namespace
|
||||
# ASTRA_DB_SKIP_COLLECTION_DELETIONS=1
|
@ -1,19 +0,0 @@
|
||||
# Absolute path of the directory that contains this file
import os

ABS_PATH = os.path.dirname(os.path.abspath(__file__))

# Absolute path of the project root: two levels above this file's directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))


def _load_env() -> None:
    """Load ``tests/integration_tests/.env`` under PROJECT_DIR when it exists."""
    env_file = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env")
    if not os.path.exists(env_file):
        return

    # Imported only on this path, so dotenv is not needed without a .env file.
    from dotenv import load_dotenv

    load_dotenv(env_file)


_load_env()
|
@ -1,198 +0,0 @@
|
||||
import os
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
import pytest
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from langchain_astradb.chat_message_histories import (
|
||||
AstraDBChatMessageHistory,
|
||||
)
|
||||
from langchain_astradb.utils.astradb import SetupMode
|
||||
|
||||
|
||||
def _has_env_vars() -> bool:
    """Return True only when both mandatory Astra DB variables are defined."""
    required = ("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT")
    return all(name in os.environ for name in required)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def history1() -> Iterable[AstraDBChatMessageHistory]:
    """Yield a chat history bound to session "session-test-1".

    The "langchain_cmh_test" collection is deleted once the test is over.
    """
    collection_name = "langchain_cmh_test"
    chat_history = AstraDBChatMessageHistory(
        session_id="session-test-1",
        collection_name=collection_name,
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield chat_history
    chat_history.collection.astra_db.delete_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def history2() -> Iterable[AstraDBChatMessageHistory]:
    """Yield a chat history bound to session "session-test-2".

    The "langchain_cmh_test" collection is deleted once the test is over.
    """
    collection_name = "langchain_cmh_test"
    chat_history = AstraDBChatMessageHistory(
        session_id="session-test-2",
        collection_name=collection_name,
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
    )
    yield chat_history
    chat_history.collection.astra_db.delete_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.fixture
async def async_history1() -> AsyncIterable[AstraDBChatMessageHistory]:
    """Yield an ASYNC-setup chat history for session "async-session-test-1".

    The "langchain_cmh_test" collection is deleted once the test is over.
    """
    collection_name = "langchain_cmh_test"
    chat_history = AstraDBChatMessageHistory(
        session_id="async-session-test-1",
        collection_name=collection_name,
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        setup_mode=SetupMode.ASYNC,
    )
    yield chat_history
    await chat_history.async_collection.astra_db.delete_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
async def async_history2() -> AsyncIterable[AstraDBChatMessageHistory]:
    """Yield an ASYNC-setup chat history for session "async-session-test-2".

    The "langchain_cmh_test" collection is deleted once the test is over.
    """
    collection_name = "langchain_cmh_test"
    chat_history = AstraDBChatMessageHistory(
        session_id="async-session-test-2",
        collection_name=collection_name,
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        namespace=os.environ.get("ASTRA_DB_KEYSPACE"),
        setup_mode=SetupMode.ASYNC,
    )
    yield chat_history
    await chat_history.async_collection.astra_db.delete_collection(collection_name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_with_message_store(history1: AstraDBChatMessageHistory) -> None:
    """Messages added to the store are read back in order and can be cleared."""
    memory = ConversationBufferMemory(
        memory_key="baz",
        chat_memory=history1,
        return_messages=True,
    )
    chat_memory = memory.chat_memory

    # the store starts out empty
    assert chat_memory.messages == []

    # write one AI message followed by one human message
    chat_memory.add_messages(
        [
            AIMessage(content="This is me, the AI"),
            HumanMessage(content="This is me, the human"),
        ]
    )

    # both messages come back, in insertion order
    stored = chat_memory.messages
    assert stored == [
        AIMessage(content="This is me, the AI"),
        HumanMessage(content="This is me, the human"),
    ]

    # clearing leaves the store empty again
    chat_memory.clear()
    assert chat_memory.messages == []
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_with_message_store_async(
    async_history1: AstraDBChatMessageHistory,
) -> None:
    """Async variant: messages are read back in order and can be cleared."""
    memory = ConversationBufferMemory(
        memory_key="baz",
        chat_memory=async_history1,
        return_messages=True,
    )
    chat_memory = memory.chat_memory

    # the store starts out empty
    assert await chat_memory.aget_messages() == []

    # write one AI message followed by one human message
    await chat_memory.aadd_messages(
        [
            AIMessage(content="This is me, the AI"),
            HumanMessage(content="This is me, the human"),
        ]
    )

    # both messages come back, in insertion order
    stored = await chat_memory.aget_messages()
    assert stored == [
        AIMessage(content="This is me, the AI"),
        HumanMessage(content="This is me, the human"),
    ]

    # clearing leaves the store empty again
    await chat_memory.aclear()
    assert await chat_memory.aget_messages() == []
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
def test_memory_separate_session_ids(
    history1: AstraDBChatMessageHistory, history2: AstraDBChatMessageHistory
) -> None:
    """Entries written under one session ID are invisible to the other."""
    first = ConversationBufferMemory(
        memory_key="mk1",
        chat_memory=history1,
        return_messages=True,
    ).chat_memory
    second = ConversationBufferMemory(
        memory_key="mk2",
        chat_memory=history2,
        return_messages=True,
    ).chat_memory

    # a message written to the first session does not show up in the second
    first.add_messages([AIMessage(content="Just saying.")])
    assert second.messages == []

    # clearing the second session leaves the first one untouched
    second.clear()
    assert first.messages != []

    # clearing the first session empties it
    first.clear()
    assert first.messages == []
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
async def test_memory_separate_session_ids_async(
    async_history1: AstraDBChatMessageHistory, async_history2: AstraDBChatMessageHistory
) -> None:
    """Async variant: entries of one session ID are invisible to the other."""
    first = ConversationBufferMemory(
        memory_key="mk1",
        chat_memory=async_history1,
        return_messages=True,
    ).chat_memory
    second = ConversationBufferMemory(
        memory_key="mk2",
        chat_memory=async_history2,
        return_messages=True,
    ).chat_memory

    # a message written to the first session does not show up in the second
    await first.aadd_messages([AIMessage(content="Just saying.")])
    assert await second.aget_messages() == []

    # clearing the second session leaves the first one untouched
    await second.aclear()
    assert await first.aget_messages() != []

    # clearing the first session empties it
    await first.aclear()
    assert await first.aget_messages() == []
|
@ -1,7 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
def test_placeholder() -> None:
    """Do nothing: lets the integration tests be compiled without running them."""
|
@ -1,176 +0,0 @@
|
||||
"""Implement integration tests for AstraDB storage."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from astrapy.db import AstraDB, AsyncAstraDB
|
||||
|
||||
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
|
||||
from langchain_astradb.utils.astradb import SetupMode
|
||||
|
||||
|
||||
def _has_env_vars() -> bool:
    """Tell whether the Astra DB token and API endpoint are both in the env."""
    for variable in ("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT"):
        if variable not in os.environ:
            return False
    return True
|
||||
|
||||
|
||||
@pytest.fixture
def astra_db() -> AstraDB:
    """Build a synchronous astrapy client from the ASTRA_DB_* env variables."""
    from astrapy.db import AstraDB

    token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
    api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"]
    # the keyspace is optional: None is passed when the variable is unset
    namespace = os.environ.get("ASTRA_DB_KEYSPACE")
    return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)
|
||||
|
||||
|
||||
@pytest.fixture
def async_astra_db() -> AsyncAstraDB:
    """Build an asynchronous astrapy client from the ASTRA_DB_* env variables."""
    from astrapy.db import AsyncAstraDB

    token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
    api_endpoint = os.environ["ASTRA_DB_API_ENDPOINT"]
    # the keyspace is optional: None is passed when the variable is unset
    namespace = os.environ.get("ASTRA_DB_KEYSPACE")
    return AsyncAstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)
|
||||
|
||||
|
||||
def init_store(astra_db: AstraDB, collection_name: str) -> AstraDBStore:
    """Create an AstraDBStore on ``collection_name`` holding two sample entries."""
    sample_items = [("key1", [0.1, 0.2]), ("key2", "value2")]
    kv_store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
    kv_store.mset(sample_items)
    return kv_store
|
||||
|
||||
|
||||
def init_bytestore(astra_db: AstraDB, collection_name: str) -> AstraDBByteStore:
    """Create an AstraDBByteStore on ``collection_name`` with two byte values."""
    sample_items = [("key1", b"value1"), ("key2", b"value2")]
    byte_store = AstraDBByteStore(
        collection_name=collection_name, astra_db_client=astra_db
    )
    byte_store.mset(sample_items)
    return byte_store
|
||||
|
||||
|
||||
async def init_async_store(
    async_astra_db: AsyncAstraDB, collection_name: str
) -> AstraDBStore:
    """Create an ASYNC-setup AstraDBStore holding two sample entries."""
    sample_items = [("key1", [0.1, 0.2]), ("key2", "value2")]
    kv_store = AstraDBStore(
        collection_name=collection_name,
        async_astra_db_client=async_astra_db,
        setup_mode=SetupMode.ASYNC,
    )
    await kv_store.amset(sample_items)
    return kv_store
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBStore:
    """Integration tests for AstraDBStore and AstraDBByteStore.

    Each test deletes the collection it used in a ``finally`` clause, so the
    deletion is attempted even when an assertion fails.
    """

    def test_mget(self, astra_db: AstraDB) -> None:
        """Test AstraDBStore mget method."""
        name = "lc_test_store_mget"
        try:
            kv_store = init_store(astra_db, name)
            fetched = kv_store.mget(["key1", "key2"])
            assert fetched == [[0.1, 0.2], "value2"]
        finally:
            astra_db.delete_collection(name)

    async def test_amget(self, async_astra_db: AsyncAstraDB) -> None:
        """Test AstraDBStore amget method."""
        name = "lc_test_store_mget"
        try:
            kv_store = await init_async_store(async_astra_db, name)
            fetched = await kv_store.amget(["key1", "key2"])
            assert fetched == [[0.1, 0.2], "value2"]
        finally:
            await async_astra_db.delete_collection(name)

    def test_mset(self, astra_db: AstraDB) -> None:
        """Test that multiple keys can be set with AstraDBStore."""
        name = "lc_test_store_mset"
        try:
            kv_store = init_store(astra_db, name)
            # read the raw documents straight from the underlying collection
            first = kv_store.collection.find_one({"_id": "key1"})
            assert first["data"]["document"]["value"] == [0.1, 0.2]
            second = kv_store.collection.find_one({"_id": "key2"})
            assert second["data"]["document"]["value"] == "value2"
        finally:
            astra_db.delete_collection(name)

    async def test_amset(self, async_astra_db: AsyncAstraDB) -> None:
        """Test that multiple keys can be set with AstraDBStore (async)."""
        name = "lc_test_store_mset"
        try:
            kv_store = await init_async_store(async_astra_db, name)
            # read the raw documents straight from the underlying collection
            first = await kv_store.async_collection.find_one({"_id": "key1"})
            assert first["data"]["document"]["value"] == [0.1, 0.2]
            second = await kv_store.async_collection.find_one({"_id": "key2"})
            assert second["data"]["document"]["value"] == "value2"
        finally:
            await async_astra_db.delete_collection(name)

    def test_mdelete(self, astra_db: AstraDB) -> None:
        """Test that deletion works as expected."""
        name = "lc_test_store_mdelete"
        try:
            kv_store = init_store(astra_db, name)
            kv_store.mdelete(["key1", "key2"])
            remaining = kv_store.mget(["key1", "key2"])
            assert remaining == [None, None]
        finally:
            astra_db.delete_collection(name)

    async def test_amdelete(self, async_astra_db: AsyncAstraDB) -> None:
        """Test that deletion works as expected (async)."""
        name = "lc_test_store_mdelete"
        try:
            kv_store = await init_async_store(async_astra_db, name)
            await kv_store.amdelete(["key1", "key2"])
            remaining = await kv_store.amget(["key1", "key2"])
            assert remaining == [None, None]
        finally:
            await async_astra_db.delete_collection(name)

    def test_yield_keys(self, astra_db: AstraDB) -> None:
        """Test key listing without a prefix and with two different prefixes."""
        name = "lc_test_store_yield_keys"
        try:
            kv_store = init_store(astra_db, name)
            both_keys = {"key1", "key2"}
            assert set(kv_store.yield_keys()) == both_keys
            assert set(kv_store.yield_keys(prefix="key")) == both_keys
            assert set(kv_store.yield_keys(prefix="lang")) == set()
        finally:
            astra_db.delete_collection(name)

    async def test_ayield_keys(self, async_astra_db: AsyncAstraDB) -> None:
        """Test async key listing without and with a prefix."""
        name = "lc_test_store_yield_keys"
        try:
            kv_store = await init_async_store(async_astra_db, name)
            both_keys = {"key1", "key2"}
            assert {key async for key in kv_store.ayield_keys()} == both_keys
            assert {
                key async for key in kv_store.ayield_keys(prefix="key")
            } == both_keys
            assert {key async for key in kv_store.ayield_keys(prefix="lang")} == set()
        finally:
            await async_astra_db.delete_collection(name)

    def test_bytestore_mget(self, astra_db: AstraDB) -> None:
        """Test AstraDBByteStore mget method."""
        name = "lc_test_bytestore_mget"
        try:
            byte_store = init_bytestore(astra_db, name)
            fetched = byte_store.mget(["key1", "key2"])
            assert fetched == [b"value1", b"value2"]
        finally:
            astra_db.delete_collection(name)

    def test_bytestore_mset(self, astra_db: AstraDB) -> None:
        """Test that multiple keys can be set with AstraDBByteStore."""
        name = "lc_test_bytestore_mset"
        try:
            byte_store = init_bytestore(astra_db, name)
            # the expected raw values are the base64 text of the stored bytes
            first = byte_store.collection.find_one({"_id": "key1"})
            assert first["data"]["document"]["value"] == "dmFsdWUx"
            second = byte_store.collection.find_one({"_id": "key2"})
            assert second["data"]["document"]["value"] == "dmFsdWUy"
        finally:
            astra_db.delete_collection(name)
|
@ -1,868 +0,0 @@
|
||||
"""
|
||||
Test of Astra DB vector store class `AstraDBVectorStore`
|
||||
|
||||
Required to run this test:
|
||||
- a recent `astrapy` Python package available
|
||||
- an Astra DB instance;
|
||||
- the two environment variables set:
|
||||
export ASTRA_DB_API_ENDPOINT="https://<DB-ID>-us-east1.apps.astra.datastax.com"
|
||||
export ASTRA_DB_APPLICATION_TOKEN="AstraCS:........."
|
||||
- optionally this as well (otherwise defaults are used):
|
||||
export ASTRA_DB_KEYSPACE="my_keyspace"
|
||||
- optionally:
|
||||
export ASTRA_DB_SKIP_COLLECTION_DELETIONS="0" ("0" = delete collections, default; "1" = no deletions)
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Iterable, List, Optional, TypedDict
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_astradb.vectorstores import AstraDBVectorStore
|
||||
|
||||
# Faster testing (no actual collection deletions). Off by default (=full tests)
def _env_flag_enabled(raw_value: str) -> bool:
    """Interpret the text of an environment variable as an on/off flag.

    Integer texts keep their previous meaning: any non-zero integer enables
    the flag and "0" disables it. The words "true", "yes" and "on" (in any
    letter case) also enable it; any other text disables it instead of
    raising ValueError when this module is imported.

    Args:
        raw_value: the text read from the environment.

    Returns:
        True if the flag is enabled, False otherwise.
    """
    text = raw_value.strip().lower()
    try:
        return int(text) != 0
    except ValueError:
        return text in {"true", "yes", "on"}


SKIP_COLLECTION_DELETE = _env_flag_enabled(
    os.environ.get("ASTRA_DB_SKIP_COLLECTION_DELETIONS", "0")
)
|
||||
|
||||
COLLECTION_NAME_DIM2 = "lc_test_d2"
|
||||
COLLECTION_NAME_DIM2_EUCLIDEAN = "lc_test_d2_eucl"
|
||||
|
||||
MATCH_EPSILON = 0.0001
|
||||
|
||||
# Ad-hoc embedding classes:
|
||||
|
||||
|
||||
class AstraDBCredentials(TypedDict):
    """Keyword arguments used by these tests to connect to Astra DB."""

    # application token
    token: str
    # API endpoint URL of the database
    api_endpoint: str
    # keyspace to use; may be None
    namespace: Optional[str]
|
||||
|
||||
|
||||
class SomeEmbeddings(Embeddings):
    """
    Deterministic toy embedding: the code points of the first characters of a
    text are turned into a unit-norm vector of the configured dimension.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with ``embed_query``."""
        vectors = []
        for text in texts:
            vectors.append(self.embed_query(text))
        return vectors

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async wrapper that calls the synchronous ``embed_documents``."""
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the normalized vector for ``text``."""
        dim = self.dimension
        codes = [ord(char) for char in text[:dim]]
        # short texts get a single 1 followed by zeros up to the dimension;
        # the final slice drops the extra 1 when the text already fills it
        padded = codes + [1] + [0] * (dim - 1 - len(codes))
        raw = padded[:dim]
        length = sum(value * value for value in raw) ** 0.5
        return [value / length for value in raw]

    async def aembed_query(self, text: str) -> List[float]:
        """Async wrapper that calls the synchronous ``embed_query``."""
        return self.embed_query(text)
|
||||
|
||||
|
||||
class ParserEmbeddings(Embeddings):
    """
    Embedding that reads the vector from the text itself: a text that is the
    JSON of a list with the right length is returned as parsed; anything else
    yields an all-zero vector.
    """

    def __init__(self, dimension: int) -> None:
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text with ``embed_query``."""
        vectors = []
        for text in texts:
            vectors.append(self.embed_query(text))
        return vectors

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async wrapper that calls the synchronous ``embed_documents``."""
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Parse ``text`` as JSON; fall back to zeros on any failure."""
        try:
            parsed = json.loads(text)
            assert len(parsed) == self.dimension
        except Exception:
            # invalid JSON, a value without a length, or a wrong length
            print(f'[ParserEmbeddings] Returning a moot vector for "{text}"')
            return [0.0] * self.dimension
        else:
            return parsed

    async def aembed_query(self, text: str) -> List[float]:
        """Async wrapper that calls the synchronous ``embed_query``."""
        return self.embed_query(text)
|
||||
|
||||
|
||||
def _has_env_vars() -> bool:
    """Report whether the token and the API endpoint variables are both set."""
    has_token = "ASTRA_DB_APPLICATION_TOKEN" in os.environ
    has_endpoint = "ASTRA_DB_API_ENDPOINT" in os.environ
    return has_token and has_endpoint
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def astradb_credentials() -> Iterable[AstraDBCredentials]:
    """Yield the connection parameters read from the ASTRA_DB_* env variables."""
    credentials: AstraDBCredentials = {
        "token": os.environ["ASTRA_DB_APPLICATION_TOKEN"],
        "api_endpoint": os.environ["ASTRA_DB_API_ENDPOINT"],
        # optional: None when the keyspace variable is not set
        "namespace": os.environ.get("ASTRA_DB_KEYSPACE"),
    }
    yield credentials
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def store_someemb(
    astradb_credentials: AstraDBCredentials,
) -> Iterable[AstraDBVectorStore]:
    """Yield an emptied 2-dimensional vector store using SomeEmbeddings.

    On teardown the collection is deleted, or only cleared when
    SKIP_COLLECTION_DELETE is set.
    """
    store = AstraDBVectorStore(
        embedding=SomeEmbeddings(dimension=2),
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    # start every test from an empty collection
    store.clear()

    yield store

    if SKIP_COLLECTION_DELETE:
        store.clear()
    else:
        store.delete_collection()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def store_parseremb(
    astradb_credentials: AstraDBCredentials,
) -> Iterable[AstraDBVectorStore]:
    """Yield an emptied 2-dimensional vector store using ParserEmbeddings.

    On teardown the collection is deleted, or only cleared when
    SKIP_COLLECTION_DELETE is set.
    """
    store = AstraDBVectorStore(
        embedding=ParserEmbeddings(dimension=2),
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    # start every test from an empty collection
    store.clear()

    yield store

    if SKIP_COLLECTION_DELETE:
        store.clear()
    else:
        store.delete_collection()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
|
||||
class TestAstraDBVectorStore:
|
||||
def test_astradb_vectorstore_create_delete(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """Create a store in two ways, add one text to each, then clean up.

    The store is first built from the connection secrets and then from a
    ready-made astrapy client. Cleanup deletes the collection, or only clears
    it when SKIP_COLLECTION_DELETE is set.
    """
    from astrapy.db import AstraDB as LibAstraDB

    emb = SomeEmbeddings(dimension=2)
    # creation by passing the connection secrets
    v_store = AstraDBVectorStore(
        embedding=emb,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    # add_texts iterates over its argument: a bare string would be inserted
    # character by character, so the sample text is wrapped in a list.
    v_store.add_texts(["Sample 1"])
    if not SKIP_COLLECTION_DELETE:
        v_store.delete_collection()
    else:
        v_store.clear()

    # Creation by passing a ready-made astrapy client:
    astra_db_client = LibAstraDB(
        **astradb_credentials,
    )
    v_store_2 = AstraDBVectorStore(
        embedding=emb,
        collection_name=COLLECTION_NAME_DIM2,
        astra_db_client=astra_db_client,
    )
    v_store_2.add_texts(["Sample 2"])
    if not SKIP_COLLECTION_DELETE:
        v_store_2.delete_collection()
    else:
        v_store_2.clear()
|
||||
|
||||
async def test_astradb_vectorstore_create_delete_async(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """Create a store in two ways and clean each one up asynchronously.

    The store is first built from the connection secrets and then from a
    ready-made async astrapy client. Cleanup deletes the collection, or only
    clears it when SKIP_COLLECTION_DELETE is set.
    """
    emb = SomeEmbeddings(dimension=2)
    # creation by passing the connection secrets
    v_store = AstraDBVectorStore(
        embedding=emb,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    # Honour SKIP_COLLECTION_DELETE here as well, as is done for the second
    # store below and in the synchronous version of this test.
    if not SKIP_COLLECTION_DELETE:
        await v_store.adelete_collection()
    else:
        await v_store.aclear()
    # Creation by passing a ready-made astrapy client:
    from astrapy.db import AsyncAstraDB

    astra_db_client = AsyncAstraDB(
        **astradb_credentials,
    )
    v_store_2 = AstraDBVectorStore(
        embedding=emb,
        collection_name="lc_test_2_async",
        async_astra_db_client=astra_db_client,
    )
    if not SKIP_COLLECTION_DELETE:
        await v_store_2.adelete_collection()
    else:
        await v_store_2.aclear()
|
||||
|
||||
@pytest.mark.skipif(
    SKIP_COLLECTION_DELETE,
    reason="Collection-deletion tests are suppressed",
)
def test_astradb_vectorstore_pre_delete_collection(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """A store created with pre_delete_collection=True starts out empty."""
    embedding = SomeEmbeddings(dimension=2)
    store = AstraDBVectorStore(
        embedding=embedding,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    store.clear()
    try:
        # one entry is written and can be found
        store.add_texts(texts=["aa"], metadatas=[{"k": "a", "ord": 0}], ids=["a"])
        hits = store.similarity_search("aa", k=5)
        assert len(hits) == 1
        # re-creating the store on the same collection with the flag set
        # leaves nothing to be found
        store = AstraDBVectorStore(
            embedding=embedding,
            pre_delete_collection=True,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        hits = store.similarity_search("aa", k=5)
        assert len(hits) == 0
    finally:
        store.delete_collection()
|
||||
|
||||
@pytest.mark.skipif(
    SKIP_COLLECTION_DELETE,
    reason="Collection-deletion tests are suppressed",
)
async def test_astradb_vectorstore_pre_delete_collection_async(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """Use of the pre_delete_collection flag (async methods)."""
    emb = SomeEmbeddings(dimension=2)
    # creation by passing the connection secrets

    v_store = AstraDBVectorStore(
        embedding=emb,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    # Start from an empty collection, as the synchronous version of this test
    # does: leftover entries would break the "exactly one result" check below.
    await v_store.aclear()
    try:
        await v_store.aadd_texts(
            texts=["aa"],
            metadatas=[
                {"k": "a", "ord": 0},
            ],
            ids=["a"],
        )
        res1 = await v_store.asimilarity_search("aa", k=5)
        assert len(res1) == 1
        # re-create the store on the same collection with the flag set:
        # nothing is expected to be found afterwards
        v_store = AstraDBVectorStore(
            embedding=emb,
            pre_delete_collection=True,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )
        res1 = await v_store.asimilarity_search("aa", k=5)
        assert len(res1) == 0
    finally:
        await v_store.adelete_collection()
|
||||
|
||||
def test_astradb_vectorstore_from_x(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """Stores built by from_texts and from_documents return the closest text."""

    def _dispose(store: AstraDBVectorStore) -> None:
        # delete the collection, or only clear it when deletions are skipped
        if SKIP_COLLECTION_DELETE:
            store.clear()
        else:
            store.delete_collection()

    embedding = SomeEmbeddings(dimension=2)
    # prepare empty collection
    AstraDBVectorStore(
        embedding=embedding,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    ).clear()

    # from_texts
    texts_store = AstraDBVectorStore.from_texts(
        texts=["Hi", "Ho"],
        embedding=embedding,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    try:
        best_match = texts_store.similarity_search("Ho", k=1)[0]
        assert best_match.page_content == "Ho"
    finally:
        _dispose(texts_store)

    # from_documents
    documents = [
        Document(page_content="Hee"),
        Document(page_content="Hoi"),
    ]
    docs_store = AstraDBVectorStore.from_documents(
        documents,
        embedding=embedding,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    try:
        best_match = docs_store.similarity_search("Hoi", k=1)[0]
        assert best_match.page_content == "Hoi"
    finally:
        _dispose(docs_store)
|
||||
|
||||
async def test_astradb_vectorstore_from_x_async(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """Stores built by afrom_texts and afrom_documents return the closest text."""

    async def _dispose(store: AstraDBVectorStore) -> None:
        # delete the collection, or only clear it when deletions are skipped
        if SKIP_COLLECTION_DELETE:
            await store.aclear()
        else:
            await store.adelete_collection()

    embedding = SomeEmbeddings(dimension=2)
    # prepare empty collection
    await AstraDBVectorStore(
        embedding=embedding,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    ).aclear()

    # from_texts
    texts_store = await AstraDBVectorStore.afrom_texts(
        texts=["Hi", "Ho"],
        embedding=embedding,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    try:
        best_match = (await texts_store.asimilarity_search("Ho", k=1))[0]
        assert best_match.page_content == "Ho"
    finally:
        await _dispose(texts_store)

    # from_documents
    documents = [
        Document(page_content="Hee"),
        Document(page_content="Hoi"),
    ]
    docs_store = await AstraDBVectorStore.afrom_documents(
        documents,
        embedding=embedding,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    try:
        best_match = (await docs_store.asimilarity_search("Hoi", k=1))[0]
        assert best_match.page_content == "Hoi"
    finally:
        await _dispose(docs_store)
|
||||
|
||||
def test_astradb_vectorstore_crud(self, store_someemb: AstraDBVectorStore) -> None:
    """Add, overwrite, search, delete and clear entries of the store."""
    store = store_someemb

    # the store starts out empty
    assert store.similarity_search("Abc", k=2) == []

    # write three entries and read them back
    store.add_texts(
        texts=["aa", "bb", "cc"],
        metadatas=[
            {"k": "a", "ord": 0},
            {"k": "b", "ord": 1},
            {"k": "c", "ord": 2},
        ],
        ids=["a", "b", "c"],
    )
    after_insert = store.similarity_search("Abc", k=5)
    assert {doc.page_content for doc in after_insert} == {"aa", "bb", "cc"}

    # overwrite "c" and add "d": four entries in total
    store.add_texts(
        texts=["cc", "dd"],
        metadatas=[
            {"k": "c_new", "ord": 102},
            {"k": "d_new", "ord": 103},
        ],
        ids=["c", "d"],
    )
    after_overwrite = store.similarity_search("Abc", k=10)
    assert len(after_overwrite) == 4

    # the overwritten entry carries the new metadata
    scored_hits = store.similarity_search_with_score_id(
        query="cc", k=1, filter={"k": "c_new"}
    )
    print(str(scored_hits))
    hit_doc, hit_score, hit_id = scored_hits[0]
    assert hit_doc.page_content == "cc"
    assert hit_doc.metadata == {"k": "c_new", "ord": 102}
    assert hit_score > 0.999  # some leeway for approximations
    assert hit_id == "c"

    # delete three of the four entries
    deleted_one = store.delete(["b"])
    assert deleted_one is True
    # this batch includes "Z!", an ID that was never inserted
    deleted_batch = store.delete(["a", "c", "Z!"])
    assert deleted_batch is True
    assert len(store.similarity_search("xy", k=10)) == 1

    # clear store
    store.clear()
    assert store.similarity_search("Abc", k=2) == []

    # add_documents with "ids" arg passthrough
    store.add_documents(
        [
            Document(page_content="vv", metadata={"k": "v", "ord": 204}),
            Document(page_content="ww", metadata={"k": "w", "ord": 205}),
        ],
        ids=["v", "w"],
    )
    assert len(store.similarity_search("xy", k=10)) == 2
    filtered = store.similarity_search("ww", k=1, filter={"k": "w"})
    assert filtered[0].metadata["ord"] == 205
|
||||
|
||||
async def test_astradb_vectorstore_crud_async(
    self, store_someemb: AstraDBVectorStore
) -> None:
    """Add, overwrite, search, delete and clear entries via the async API."""
    store = store_someemb

    # the store starts out empty
    assert await store.asimilarity_search("Abc", k=2) == []

    # write three entries and read them back
    await store.aadd_texts(
        texts=["aa", "bb", "cc"],
        metadatas=[
            {"k": "a", "ord": 0},
            {"k": "b", "ord": 1},
            {"k": "c", "ord": 2},
        ],
        ids=["a", "b", "c"],
    )
    after_insert = await store.asimilarity_search("Abc", k=5)
    assert {doc.page_content for doc in after_insert} == {"aa", "bb", "cc"}

    # overwrite "c" and add "d": four entries in total
    await store.aadd_texts(
        texts=["cc", "dd"],
        metadatas=[
            {"k": "c_new", "ord": 102},
            {"k": "d_new", "ord": 103},
        ],
        ids=["c", "d"],
    )
    after_overwrite = await store.asimilarity_search("Abc", k=10)
    assert len(after_overwrite) == 4

    # the overwritten entry carries the new metadata
    scored_hits = await store.asimilarity_search_with_score_id(
        query="cc", k=1, filter={"k": "c_new"}
    )
    print(str(scored_hits))
    hit_doc, hit_score, hit_id = scored_hits[0]
    assert hit_doc.page_content == "cc"
    assert hit_doc.metadata == {"k": "c_new", "ord": 102}
    assert hit_score > 0.999  # some leeway for approximations
    assert hit_id == "c"

    # delete three of the four entries
    deleted_one = await store.adelete(["b"])
    assert deleted_one is True
    # this batch includes "Z!", an ID that was never inserted
    deleted_batch = await store.adelete(["a", "c", "Z!"])
    assert deleted_batch is False
    assert len(await store.asimilarity_search("xy", k=10)) == 1

    # clear store
    await store.aclear()
    assert await store.asimilarity_search("Abc", k=2) == []

    # add_documents with "ids" arg passthrough
    await store.aadd_documents(
        [
            Document(page_content="vv", metadata={"k": "v", "ord": 204}),
            Document(page_content="ww", metadata={"k": "w", "ord": 205}),
        ],
        ids=["v", "w"],
    )
    assert len(await store.asimilarity_search("xy", k=10)) == 2
    filtered = await store.asimilarity_search("ww", k=1, filter={"k": "w"})
    assert filtered[0].metadata["ord"] == 205
|
||||
|
||||
def test_astradb_vectorstore_mmr(self, store_parseremb: AstraDBVectorStore) -> None:
    """
    Check max-marginal-relevance search against a hand-built set of vectors.

    Every stored text is a JSON-encoded point on the unit circle, at an
    angle that is a multiple of 2*pi/20, so the MMR outcome is known
    in advance.
    """
    steps_per_turn = 20
    stored_steps = [0, 4, 5, 13]

    def _encode_direction(step: int) -> str:
        # JSON text for the unit vector at `step` twentieths of a full turn.
        theta = 2 * math.pi * step / steps_per_turn
        return json.dumps([math.cos(theta), math.sin(theta)])

    store_parseremb.add_texts(
        [_encode_direction(step) for step in stored_steps],
        metadatas=[{"i": step} for step in stored_steps],
    )

    # Querying at step 3: steps 4 and 5 are the nearest neighbours, but MMR
    # is expected to prefer the more diverse step 0 over the near-duplicate 5.
    hits = store_parseremb.max_marginal_relevance_search(
        _encode_direction(3),
        k=2,
        fetch_k=3,
    )
    assert {hit.metadata["i"] for hit in hits} == {0, 4}
|
||||
|
||||
async def test_astradb_vectorstore_mmr_async(
    self, store_parseremb: AstraDBVectorStore
) -> None:
    """
    Async variant of the max-marginal-relevance check.

    Every stored text is a JSON-encoded point on the unit circle, at an
    angle that is a multiple of 2*pi/20, so the MMR outcome is known
    in advance.
    """
    steps_per_turn = 20
    stored_steps = [0, 4, 5, 13]

    def _encode_direction(step: int) -> str:
        # JSON text for the unit vector at `step` twentieths of a full turn.
        theta = 2 * math.pi * step / steps_per_turn
        return json.dumps([math.cos(theta), math.sin(theta)])

    await store_parseremb.aadd_texts(
        [_encode_direction(step) for step in stored_steps],
        metadatas=[{"i": step} for step in stored_steps],
    )

    # Querying at step 3: steps 4 and 5 are the nearest neighbours, but MMR
    # is expected to prefer the more diverse step 0 over the near-duplicate 5.
    hits = await store_parseremb.amax_marginal_relevance_search(
        _encode_direction(3),
        k=2,
        fetch_k=3,
    )
    assert {hit.metadata["i"] for hit in hits} == {0, 4}
|
||||
|
||||
def test_astradb_vectorstore_metadata(
    self, store_someemb: AstraDBVectorStore
) -> None:
    """
    Exercise metadata filtering in similarity search.

    Six single-letter documents are stored, each tagged with its code point
    ("ord") and a "group" label, and are then queried with progressively
    more specific filters.
    """
    tagged_letters = [
        ("q", "consonant"),
        ("w", "consonant"),
        ("r", "consonant"),
        ("e", "vowel"),
        ("i", "vowel"),
        ("o", "vowel"),
    ]
    store_someemb.add_documents(
        [
            Document(
                page_content=letter,
                metadata={"ord": ord(letter), "group": group},
            )
            for letter, group in tagged_letters
        ]
    )

    # Without a filter, all six letters are expected back.
    unfiltered = store_someemb.similarity_search("x", k=10)
    assert {doc.page_content for doc in unfiltered} == {
        "q",
        "w",
        "r",
        "e",
        "i",
        "o",
    }

    # One condition.
    vowels = store_someemb.similarity_search("x", k=10, filter={"group": "vowel"})
    assert {doc.page_content for doc in vowels} == {"e", "i", "o"}

    # Two conditions that must both hold.
    only_q = store_someemb.similarity_search(
        "x",
        k=10,
        filter={"group": "consonant", "ord": ord("q")},
    )
    assert {doc.page_content for doc in only_q} == {"q"}

    # A condition on a key no document carries leaves nothing to match.
    over_constrained = store_someemb.similarity_search(
        "x",
        k=10,
        filter={"group": "consonant", "ord": ord("q"), "case": "upper"},
    )
    assert over_constrained == []

    # Logical "$or" across two alternative conditions.
    q_or_r = store_someemb.similarity_search(
        "x",
        k=10,
        filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]},
    )
    assert {doc.page_content for doc in q_or_r} == {"q", "r"}
|
||||
|
||||
def test_astradb_vectorstore_similarity_scale(
    self, store_parseremb: AstraDBVectorStore
) -> None:
    """
    Verify the range of the similarity scores.

    A vector parallel to the query is expected to score close to 1 and an
    anti-parallel one close to 0, both within MATCH_EPSILON.
    """
    parallel_text = json.dumps([1, 1])
    opposite_text = json.dumps([-1, -1])
    store_parseremb.add_texts(
        texts=[parallel_text, opposite_text],
        ids=["near", "far"],
    )

    scored_hits = store_parseremb.similarity_search_with_score(
        json.dumps([0.5, 0.5]),
        k=2,
    )
    # Exactly two (document, score) pairs are expected, best match first.
    (_, near_score), (_, far_score) = scored_hits
    assert abs(1 - near_score) < MATCH_EPSILON
    assert abs(far_score) < MATCH_EPSILON
|
||||
|
||||
async def test_astradb_vectorstore_similarity_scale_async(
    self, store_parseremb: AstraDBVectorStore
) -> None:
    """
    Async variant: verify the range of the similarity scores.

    A vector parallel to the query is expected to score close to 1 and an
    anti-parallel one close to 0, both within MATCH_EPSILON.
    """
    parallel_text = json.dumps([1, 1])
    opposite_text = json.dumps([-1, -1])
    await store_parseremb.aadd_texts(
        texts=[parallel_text, opposite_text],
        ids=["near", "far"],
    )

    scored_hits = await store_parseremb.asimilarity_search_with_score(
        json.dumps([0.5, 0.5]),
        k=2,
    )
    # Exactly two (document, score) pairs are expected, best match first.
    (_, near_score), (_, far_score) = scored_hits
    assert abs(1 - near_score) < MATCH_EPSILON
    assert abs(far_score) < MATCH_EPSILON
|
||||
|
||||
def test_astradb_vectorstore_massive_delete(
    self, store_someemb: AstraDBVectorStore
) -> None:
    """
    Bulk-delete a larger number of documents in two passes.

    One hundred documents are inserted; the first half is deleted by id,
    then the second half together with an id that was never inserted.
    """
    half = 50
    payloads = [str(n + 1 / 7.0) for n in range(2 * half)]
    first_half_ids = ["doc_%i" % n for n in range(half)]
    second_half_ids = ["doc_%i" % (n + half) for n in range(half)]
    store_someemb.add_texts(texts=payloads, ids=first_half_ids + second_half_ids)

    # First pass: only ids that exist.
    first_outcome = store_someemb.delete(first_half_ids)
    assert first_outcome is True

    # Second pass: the remaining ids plus an unknown one, which must not
    # turn the call into a failure.
    second_outcome = store_someemb.delete(second_half_ids + ["ghost!"])
    assert second_outcome is True

    # The store must now be empty.
    leftovers = store_someemb.similarity_search("x", k=2 * half)
    assert leftovers == []
|
||||
|
||||
@pytest.mark.skipif(
    SKIP_COLLECTION_DELETE,
    reason="Collection-deletion tests are suppressed",
)
def test_astradb_vectorstore_delete_collection(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """
    Check what happens after 'delete_collection' is called elsewhere.

    Two store objects are bound to the same collection; once one of them
    drops the collection, a search through the other is expected to raise
    ValueError.
    """
    embedder = SomeEmbeddings(dimension=2)

    def _open_store() -> AstraDBVectorStore:
        # Each call yields an independent object over the same collection.
        return AstraDBVectorStore(
            embedding=embedder,
            collection_name=COLLECTION_NAME_DIM2,
            **astradb_credentials,
        )

    stale_handle = _open_store()
    stale_handle.add_texts(["huh"])
    assert len(stale_handle.similarity_search("hah", k=10)) == 1

    # A second object on the same collection performs the drop ...
    dropping_handle = _open_store()
    dropping_handle.delete_collection()

    # ... which the first object is never told about, so its next query fails.
    with pytest.raises(ValueError):
        stale_handle.similarity_search("hah", k=10)
|
||||
|
||||
def test_astradb_vectorstore_custom_params(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """
    Run inserts and deletes with non-default batch/concurrency settings.

    The settings are supplied both at construction time and as per-call
    overrides; the test only requires that every call completes.
    """
    embedder = SomeEmbeddings(dimension=2)

    # Make sure the collection starts out empty.
    blank_slate = AstraDBVectorStore(
        embedding=embedder,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    )
    blank_slate.clear()

    tuned_store = AstraDBVectorStore(
        embedding=embedder,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
        batch_size=17,
        bulk_insert_batch_concurrency=13,
        bulk_insert_overwrite_concurrency=7,
        bulk_delete_concurrency=19,
    )
    try:
        doc_count = 50
        midpoint = doc_count // 2
        payloads = [str(n + 1 / 7.0) for n in range(doc_count)]
        doc_ids = ["doc_%i" % n for n in range(doc_count)]

        # First insertion uses the settings given to the constructor.
        tuned_store.add_texts(texts=payloads, ids=doc_ids)
        # Second insertion reuses the same ids with per-call overrides.
        tuned_store.add_texts(
            texts=payloads,
            ids=doc_ids,
            batch_size=19,
            batch_concurrency=7,
            overwrite_concurrency=13,
        )

        # Delete in two halves: constructor setting first, then an override.
        tuned_store.delete(doc_ids[:midpoint])
        tuned_store.delete(doc_ids[midpoint:], concurrency=23)
    finally:
        # Leave nothing behind, whichever cleanup mode is configured.
        if SKIP_COLLECTION_DELETE:
            tuned_store.clear()
        else:
            tuned_store.delete_collection()
|
||||
|
||||
async def test_astradb_vectorstore_custom_params_async(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """
    Async variant: inserts and deletes with non-default batch/concurrency.

    The settings are supplied both at construction time and as per-call
    overrides; the test only requires that every call completes.
    """
    embedder = SomeEmbeddings(dimension=2)
    tuned_store = AstraDBVectorStore(
        embedding=embedder,
        collection_name="lc_test_c_async",
        batch_size=17,
        bulk_insert_batch_concurrency=13,
        bulk_insert_overwrite_concurrency=7,
        bulk_delete_concurrency=19,
        **astradb_credentials,
    )
    try:
        doc_count = 50
        midpoint = doc_count // 2
        payloads = [str(n + 1 / 7.0) for n in range(doc_count)]
        doc_ids = ["doc_%i" % n for n in range(doc_count)]

        # First insertion uses the settings given to the constructor.
        await tuned_store.aadd_texts(texts=payloads, ids=doc_ids)
        # Second insertion reuses the same ids with per-call overrides.
        await tuned_store.aadd_texts(
            texts=payloads,
            ids=doc_ids,
            batch_size=19,
            batch_concurrency=7,
            overwrite_concurrency=13,
        )

        # Delete in two halves: constructor setting first, then an override.
        await tuned_store.adelete(doc_ids[:midpoint])
        await tuned_store.adelete(doc_ids[midpoint:], concurrency=23)
    finally:
        # Leave nothing behind, whichever cleanup mode is configured.
        if SKIP_COLLECTION_DELETE:
            await tuned_store.aclear()
        else:
            await tuned_store.adelete_collection()
|
||||
|
||||
def test_astradb_vectorstore_metrics(
    self, astradb_credentials: AstraDBCredentials
) -> None:
    """
    Compare the "cosine" and "euclidean" similarity metrics.

    Both stores receive the same two vectors:
        - "rotated": a unit vector at a slight angle to the query vector;
        - "scaled": the query vector multiplied by ten.
    Which of the two counts as the closest match depends on the metric:
    "scaled" is expected under cosine, "rotated" under euclidean.
    """
    embedder = ParserEmbeddings(dimension=2)
    isq2 = 0.5**0.5
    isa = 0.7
    isb = (1.0 - isa * isa) ** 0.5
    payloads = [
        json.dumps([isa, isb]),
        json.dumps([10 * isq2, 10 * isq2]),
    ]
    payload_ids = ["rotated", "scaled"]
    query_text = json.dumps([isq2, isq2])

    def _dispose(store: AstraDBVectorStore) -> None:
        # Drop the collection unless collection deletion is suppressed.
        if SKIP_COLLECTION_DELETE:
            store.clear()
        else:
            store.delete_collection()

    # Start from empty collections for both metrics.
    AstraDBVectorStore(
        embedding=embedder,
        collection_name=COLLECTION_NAME_DIM2,
        **astradb_credentials,
    ).clear()
    AstraDBVectorStore(
        embedding=embedder,
        collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
        metric="euclidean",
        **astradb_credentials,
    ).clear()

    # Cosine: create, populate, query.
    cosine_store = AstraDBVectorStore(
        embedding=embedder,
        collection_name=COLLECTION_NAME_DIM2,
        metric="cosine",
        **astradb_credentials,
    )
    try:
        cosine_store.add_texts(texts=payloads, ids=payload_ids)
        _doc, _score, cosine_winner = cosine_store.similarity_search_with_score_id(
            query_text,
            k=1,
        )[0]
        assert cosine_winner == "scaled"
    finally:
        _dispose(cosine_store)

    # Euclidean: create, populate, query.
    euclidean_store = AstraDBVectorStore(
        embedding=embedder,
        collection_name=COLLECTION_NAME_DIM2_EUCLIDEAN,
        metric="euclidean",
        **astradb_credentials,
    )
    try:
        euclidean_store.add_texts(texts=payloads, ids=payload_ids)
        (
            _doc,
            _score,
            euclidean_winner,
        ) = euclidean_store.similarity_search_with_score_id(
            query_text,
            k=1,
        )[0]
        assert euclidean_winner == "rotated"
    finally:
        _dispose(euclidean_store)
|
@ -1,12 +0,0 @@
|
||||
from langchain_astradb import __all__
|
||||
|
||||
# Names the package is expected to export through its `__all__`.
EXPECTED_ALL = [
    "AstraDBByteStore",
    "AstraDBStore",
    "AstraDBChatMessageHistory",
    "AstraDBVectorStore",
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
    """Check that the package's `__all__` matches the expected names, ignoring order."""
    expected_names = sorted(EXPECTED_ALL)
    exported_names = sorted(__all__)
    assert expected_names == exported_names
|
@ -1,45 +0,0 @@
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_astradb.vectorstores import AstraDBVectorStore
|
||||
|
||||
|
||||
class SomeEmbeddings(Embeddings):
    """
    Toy embedding used by the tests.

    A text is mapped to a fixed-length, unit-norm vector built from the
    code points of its leading characters. The mapping carries no meaning;
    all that matters is that it is deterministic.
    """

    def __init__(self, dimension: int) -> None:
        # Length of every vector this embedding produces.
        self.dimension = dimension

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text independently, exactly as `embed_query` would."""
        return list(map(self.embed_query, texts))

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async wrapper that delegates to the synchronous `embed_documents`."""
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the normalized vector for a single text."""
        size = self.dimension
        code_points = [ord(char) for char in text[:size]]
        # Texts shorter than `size` get a 1 followed by zeros as padding;
        # the final slice cuts the result back to `size` entries.
        padding = [1] + [0] * (size - 1 - len(code_points))
        raw_vector = (code_points + padding)[:size]
        length = sum(x * x for x in raw_vector) ** 0.5
        return [x / length for x in raw_vector]

    async def aembed_query(self, text: str) -> List[float]:
        """Async wrapper that delegates to the synchronous `embed_query`."""
        return self.embed_query(text)
|
||||
|
||||
|
||||
def test_initialization() -> None:
    """Construct the vector store with a mocked Astra DB client."""
    # The test passes as long as the constructor does not raise.
    AstraDBVectorStore(
        embedding=SomeEmbeddings(dimension=2),
        collection_name="mock_coll_name",
        astra_db_client=Mock(),
    )
|
Loading…
Reference in New Issue