mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
3331865f6b
**Description**: Toolkit and tools for accessing data in a Cassandra database, primarily for agent integration. Initially, this includes the following tools: - `cassandra_db_schema` Gathers all schema information for the connected database or a specific schema. Critical for the agent when determining actions. - `cassandra_db_select_table_data` Selects data from a specific keyspace and table. The agent can pass parameters for a predicate and limits on the number of returned records. - `cassandra_db_query` Experimental alternative to `cassandra_db_select_table_data` which takes a query string completely formed by the agent instead of parameters. May be removed in future versions. Includes unit tests and two notebooks to demonstrate usage. **Dependencies**: cassio **Twitter handle**: @PatrickMcFadin --------- Co-authored-by: Phil Miesle <phil.miesle@datastax.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
from collections import namedtuple
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from langchain_community.utilities.cassandra_database import (
|
|
CassandraDatabase,
|
|
DatabaseError,
|
|
Table,
|
|
)
|
|
|
|
# Two-column row type with fields ``col1`` and ``col2``; presumably meant to
# mimic a driver result row (NOTE(review): not referenced by the tests here).
MockRow = namedtuple("MockRow", ("col1", "col2"))
|
|
|
|
|
|
class TestCassandraDatabase:
    """Unit tests for ``CassandraDatabase`` driven by a mocked session.

    Every test runs against a ``MagicMock`` standing in for the driver
    session, so no real Cassandra cluster is contacted.
    """

    def setup_method(self) -> None:
        """Create a fresh mocked session and database wrapper per test."""
        self.mock_session = MagicMock()
        self.cassandra_db = CassandraDatabase(session=self.mock_session)

    def test_init_without_session(self) -> None:
        """Constructing the wrapper with no session raises ``ValueError``."""
        with pytest.raises(ValueError):
            CassandraDatabase()

    def test_run_query(self) -> None:
        """``run`` with the default fetch mode returns the executed rows."""
        # Mock the execute method to return an iterable of dictionaries directly
        self.mock_session.execute.return_value = iter(
            [{"col1": "val1", "col2": "val2"}]
        )

        result = self.cassandra_db.run("SELECT * FROM table")

        assert result == [{"col1": "val1", "col2": "val2"}]
        # Verify that execute was called with the expected CQL query
        self.mock_session.execute.assert_called_with("SELECT * FROM table")

    def test_run_query_cursor(self) -> None:
        """``fetch="cursor"`` hands back the session's result object itself."""
        mock_result_set = MagicMock()
        self.mock_session.execute.return_value = mock_result_set

        result = self.cassandra_db.run("SELECT * FROM table;", fetch="cursor")

        # Identity, not equality: the cursor must be passed through unchanged.
        assert result is mock_result_set

    def test_run_query_invalid_fetch(self) -> None:
        """An unknown ``fetch`` mode raises ``ValueError``."""
        with pytest.raises(ValueError):
            self.cassandra_db.run("SELECT * FROM table;", fetch="invalid")

    def test_validate_cql_select(self) -> None:
        """A single SELECT validates, with its trailing semicolon removed."""
        query = "SELECT * FROM table;"

        result = self.cassandra_db._validate_cql(query, "SELECT")

        assert result == "SELECT * FROM table"

    def test_validate_cql_unsupported_type(self) -> None:
        """Validating against the query type "UPDATE" raises ``ValueError``."""
        query = "UPDATE table SET col=val;"
        with pytest.raises(ValueError):
            self.cassandra_db._validate_cql(query, "UPDATE")

    def test_validate_cql_unsafe(self) -> None:
        """A SELECT followed by a second statement raises ``DatabaseError``."""
        query = "SELECT * FROM table; DROP TABLE table;"
        with pytest.raises(DatabaseError):
            self.cassandra_db._validate_cql(query, "SELECT")

    @patch(
        "langchain_community.utilities.cassandra_database.CassandraDatabase._resolve_schema"
    )
    def test_format_schema_to_markdown(self, mock_resolve_schema: MagicMock) -> None:
        """Rendered markdown starts with the schema title and covers each keyspace.

        ``_resolve_schema`` is patched to return two keyspaces, each holding
        one mocked ``Table``.
        """
        mock_table1 = MagicMock(spec=Table)
        mock_table1.as_markdown.return_value = "## Keyspace: keyspace1"
        mock_table2 = MagicMock(spec=Table)
        mock_table2.as_markdown.return_value = "## Keyspace: keyspace2"
        mock_resolve_schema.return_value = {
            "keyspace1": [mock_table1],
            "keyspace2": [mock_table2],
        }

        markdown = self.cassandra_db.format_schema_to_markdown()

        assert markdown.startswith("# Cassandra Database Schema")
        assert "## Keyspace: keyspace1" in markdown
        assert "## Keyspace: keyspace2" in markdown
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Collect only this module: with no args, pytest.main() would instead fall
    # back to the current directory / configured testpaths. Extra command-line
    # options are forwarded, and pytest's exit code is propagated so a failing
    # run produces a non-zero process exit status instead of always 0.
    raise SystemExit(pytest.main([__file__, *sys.argv[1:]]))
|