langchain/libs/community/tests/unit_tests/utilities/test_cassandra_database.py
Patrick McFadin 3331865f6b
community[minor]: add Cassandra Database Toolkit (#20246)
**Description**: Toolkit and Tools for accessing data in a Cassandra
Database, primarily for Agent integration. Initially, this includes the
following tools:
- `cassandra_db_schema` Gathers all schema information for the connected
database or a specific schema. Critical for the agent when determining
actions.
- `cassandra_db_select_table_data` Selects data from a specific keyspace
and table. The agent can pass parameters for a predicate and limits on
the number of returned records.
- `cassandra_db_query` Experimental alternative to
`cassandra_db_select_table_data` which takes a query string completely
formed by the agent instead of parameters. May be removed in future
versions.

Includes unit test and two notebooks to demonstrate usage. 

**Dependencies**: cassio
**Twitter handle**: @PatrickMcFadin

---------

Co-authored-by: Phil Miesle <phil.miesle@datastax.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-04-29 15:51:43 +00:00

86 lines
3.0 KiB
Python

from collections import namedtuple
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from langchain_community.utilities.cassandra_database import (
CassandraDatabase,
DatabaseError,
Table,
)
# Two-field record type whose field names match the "col1"/"col2" keys used in
# the test data below. NOTE(review): not referenced by any test visible in this
# file -- presumably kept as a stand-in for a driver row; confirm before removing.
MockRow = namedtuple("MockRow", field_names=("col1", "col2"))
class TestCassandraDatabase:
    """Unit tests for ``CassandraDatabase`` driven by a mocked session."""

    def setup_method(self) -> None:
        """Create a fresh ``MagicMock`` session and wrap it before each test."""
        session = MagicMock()
        self.mock_session = session
        self.cassandra_db = CassandraDatabase(session=session)

    def test_init_without_session(self) -> None:
        """Constructing the wrapper with no arguments must raise ``ValueError``."""
        with pytest.raises(ValueError):
            CassandraDatabase()

    def test_run_query(self) -> None:
        """``run`` returns the executed rows as a list and forwards the CQL."""
        cql = "SELECT * FROM table"
        expected_rows = [{"col1": "val1", "col2": "val2"}]
        # Arrange: the session yields an iterator over independent row copies,
        # so the equality check below compares values, not shared objects.
        self.mock_session.execute.return_value = iter(
            [dict(row) for row in expected_rows]
        )

        # Act
        rows = self.cassandra_db.run(cql)

        # Assert: rows come back as a list and the query reached the session.
        assert rows == expected_rows
        self.mock_session.execute.assert_called_with(cql)

    def test_run_query_cursor(self) -> None:
        """With ``fetch="cursor"`` the session's result object is returned."""
        cursor = MagicMock()
        self.mock_session.execute.return_value = cursor

        returned = self.cassandra_db.run("SELECT * FROM table;", fetch="cursor")

        assert returned == cursor

    def test_run_query_invalid_fetch(self) -> None:
        """An unrecognised ``fetch`` value must raise ``ValueError``."""
        cql = "SELECT * FROM table;"
        with pytest.raises(ValueError):
            self.cassandra_db.run(cql, fetch="invalid")

    def test_validate_cql_select(self) -> None:
        """A valid SELECT is returned without its trailing semicolon."""
        cleaned = self.cassandra_db._validate_cql("SELECT * FROM table;", "SELECT")
        assert cleaned == "SELECT * FROM table"

    def test_validate_cql_unsupported_type(self) -> None:
        """Validating against the ``UPDATE`` type must raise ``ValueError``."""
        with pytest.raises(ValueError):
            self.cassandra_db._validate_cql("UPDATE table SET col=val;", "UPDATE")

    def test_validate_cql_unsafe(self) -> None:
        """A SELECT followed by a second statement must raise ``DatabaseError``."""
        stacked = "SELECT * FROM table; DROP TABLE table;"
        with pytest.raises(DatabaseError):
            self.cassandra_db._validate_cql(stacked, "SELECT")

    @patch(
        "langchain_community.utilities.cassandra_database.CassandraDatabase._resolve_schema"
    )
    def test_format_schema_to_markdown(self, mock_resolve_schema: Any) -> None:
        """The markdown has the top-level title and each table's own markdown."""
        # Markdown each mocked table reports, keyed by the keyspace it lives in.
        headings = {
            "keyspace1": "## Keyspace: keyspace1",
            "keyspace2": "## Keyspace: keyspace2",
        }
        schema = {}
        for keyspace, heading in headings.items():
            table = MagicMock(spec=Table)
            table.as_markdown.return_value = heading
            schema[keyspace] = [table]
        # _resolve_schema is patched so no real schema lookup takes place.
        mock_resolve_schema.return_value = schema

        markdown = self.cassandra_db.format_schema_to_markdown()

        assert markdown.startswith("# Cassandra Database Schema")
        for heading in headings.values():
            assert heading in markdown
if __name__ == "__main__":
    # Hand control to pytest when this module is executed as a script. With no
    # explicit arguments, pytest.main() takes its options from the process's
    # command line.
    pytest.main()