mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
community[patch]: Refactor CassandraDatabase wrapper (#21075)
* Introduce individual `fetch_` methods for easier typing.
* Rework some docstrings to Google style.
* Move some logic to the tool.
* Merge the 2 Cassandra utility files.
This commit is contained in:
parent
b00fd1dbde
commit
683fb45c6b
@ -1,6 +1,7 @@
|
||||
"""Tools for interacting with an Apache Cassandra database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
@ -43,7 +44,11 @@ class QueryCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool):
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> Union[str, Sequence[Dict[str, Any]], ResultSet]:
|
||||
"""Execute the query, return the results or an error message."""
|
||||
return self.db.run_no_throw(query)
|
||||
try:
|
||||
return self.db.run(query)
|
||||
except Exception as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}\n{traceback.format_exc()}"
|
||||
|
||||
|
||||
class _GetSchemaCassandraDatabaseToolInput(BaseModel):
|
||||
@ -73,7 +78,12 @@ class GetSchemaCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool):
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get the schema for a keyspace."""
|
||||
return self.db.get_keyspace_tables_str_no_throw(keyspace)
|
||||
try:
|
||||
tables = self.db.get_keyspace_tables(keyspace)
|
||||
return "".join([table.as_markdown() + "\n\n" for table in tables])
|
||||
except Exception as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}\n{traceback.format_exc()}"
|
||||
|
||||
|
||||
class _GetTableDataCassandraDatabaseToolInput(BaseModel):
|
||||
@ -123,4 +133,8 @@ class GetTableDataCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool):
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Get data from a table in a keyspace."""
|
||||
return self.db.get_table_data_no_throw(keyspace, table, predicate, limit)
|
||||
try:
|
||||
return self.db.get_table_data(keyspace, table, predicate, limit)
|
||||
except Exception as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}\n{traceback.format_exc()}"
|
||||
|
@ -2,7 +2,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
@ -32,9 +31,10 @@ class CassandraDatabase:
|
||||
include_tables: Optional[List[str]] = None,
|
||||
cassio_init_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self._session = self._resolve_session(session, cassio_init_kwargs)
|
||||
if not self._session:
|
||||
_session = self._resolve_session(session, cassio_init_kwargs)
|
||||
if not _session:
|
||||
raise ValueError("Session not provided and cannot be resolved")
|
||||
self._session = _session
|
||||
|
||||
self._exclude_keyspaces = IGNORED_KEYSPACES
|
||||
self._exclude_tables = exclude_tables or []
|
||||
@ -44,52 +44,28 @@ class CassandraDatabase:
|
||||
self,
|
||||
query: str,
|
||||
fetch: str = "all",
|
||||
include_columns: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[str, Sequence[Dict[str, Any]], ResultSet]:
|
||||
) -> Union[list, Dict[str, Any], ResultSet]:
|
||||
"""Execute a CQL query and return the results."""
|
||||
clean_query = self._validate_cql(query, "SELECT")
|
||||
result = self._session.execute(clean_query, **kwargs)
|
||||
if fetch == "all":
|
||||
return list(result)
|
||||
return self.fetch_all(query, **kwargs)
|
||||
elif fetch == "one":
|
||||
return result.one()._asdict() if result else {}
|
||||
return self.fetch_one(query, **kwargs)
|
||||
elif fetch == "cursor":
|
||||
return result
|
||||
return self._fetch(query, **kwargs)
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one', 'all', or 'cursor'")
|
||||
|
||||
def run_no_throw(
|
||||
self,
|
||||
query: str,
|
||||
fetch: str = "all",
|
||||
include_columns: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[str, Sequence[Dict[str, Any]], ResultSet]:
|
||||
"""Execute a CQL query and return the results or an error message."""
|
||||
try:
|
||||
return self.run(query, fetch, include_columns, **kwargs)
|
||||
except Exception as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}\n{traceback.format_exc()}"
|
||||
def _fetch(self, query: str, **kwargs: Any) -> ResultSet:
    """Validate ``query`` as a SELECT statement and execute it.

    Args:
        query: The CQL text to run.
        **kwargs: Extra keyword arguments passed through unchanged to the
            session's ``execute`` call.

    Returns:
        Whatever the session's ``execute`` returns for the validated query.
    """
    # Validation happens first so that an exception raised by the validator
    # prevents the statement from ever reaching the session.
    validated = self._validate_cql(query, "SELECT")
    session = self._session
    return session.execute(validated, **kwargs)
|
||||
|
||||
def get_keyspace_tables_str_no_throw(self, keyspace: str) -> str:
|
||||
"""Get the tables for the specified keyspace."""
|
||||
try:
|
||||
schema_string = self.get_keyspace_tables_str(keyspace)
|
||||
return schema_string
|
||||
except Exception as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}\n{traceback.format_exc()}"
|
||||
def fetch_all(self, query: str, **kwargs: Any) -> list:
    """Run ``query`` and return every resulting row as a list.

    Args:
        query: The CQL text to run; it is handed to ``self._fetch``.
        **kwargs: Extra keyword arguments forwarded to ``self._fetch``.

    Returns:
        A list built by iterating the object returned by ``self._fetch``.
    """
    rows = self._fetch(query, **kwargs)
    return list(rows)
|
||||
|
||||
def get_keyspace_tables_str(self, keyspace: str) -> str:
|
||||
"""Get the tables for the specified keyspace."""
|
||||
tables = self.get_keyspace_tables(keyspace)
|
||||
schema_string = ""
|
||||
for table in tables:
|
||||
schema_string += table.as_markdown() + "\n\n"
|
||||
|
||||
return schema_string
|
||||
def fetch_one(self, query: str, **kwargs: Any) -> Dict[str, Any]:
    """Run ``query`` and return its first row as a dictionary.

    Args:
        query: The CQL text to run; it is handed to ``self._fetch``.
        **kwargs: Extra keyword arguments forwarded to ``self._fetch``.

    Returns:
        The result of calling ``_asdict()`` on the row returned by the
        result's ``one()`` method, or an empty dict when the result is falsy.
    """
    rows = self._fetch(query, **kwargs)
    # A falsy result means there is nothing to convert.
    if not rows:
        return {}
    first_row = rows.one()
    return first_row._asdict()
|
||||
|
||||
def get_keyspace_tables(self, keyspace: str) -> List[Table]:
|
||||
"""Get the Table objects for the specified keyspace."""
|
||||
@ -99,17 +75,6 @@ class CassandraDatabase:
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_table_data_no_throw(
|
||||
self, keyspace: str, table: str, predicate: str, limit: int
|
||||
) -> str:
|
||||
"""Get data from the specified table in the specified keyspace. Optionally can
|
||||
take a predicate for the WHERE clause and a limit."""
|
||||
try:
|
||||
return self.get_table_data(keyspace, table, predicate, limit)
|
||||
except Exception as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}\n{traceback.format_exc()}"
|
||||
|
||||
# This is a more basic string building function that doesn't use a query builder
|
||||
# or prepared statements
|
||||
# TODO: Refactor to use prepared statements
|
||||
@ -127,7 +92,7 @@ class CassandraDatabase:
|
||||
|
||||
query += ";"
|
||||
|
||||
result = self.run(query, fetch="all")
|
||||
result = self.fetch_all(query)
|
||||
data = "\n".join(str(row) for row in result)
|
||||
return data
|
||||
|
||||
@ -144,15 +109,13 @@ class CassandraDatabase:
|
||||
by iterating over all tables within that keyspace and calling their
|
||||
as_markdown method.
|
||||
|
||||
Parameters:
|
||||
- keyspace (str): The name of the keyspace to generate markdown
|
||||
documentation for.
|
||||
- tables (list[Table]): list of tables in the keyspace; it will be resolved
|
||||
if not provided.
|
||||
Args:
|
||||
keyspace: The name of the keyspace to generate markdown documentation for.
|
||||
tables: list of tables in the keyspace; it will be resolved if not provided.
|
||||
|
||||
Returns:
|
||||
A string containing the markdown representation of the specified
|
||||
keyspace schema.
|
||||
A string containing the markdown representation of the specified
|
||||
keyspace schema.
|
||||
"""
|
||||
if not tables:
|
||||
tables = self.get_keyspace_tables(keyspace)
|
||||
@ -184,10 +147,10 @@ class CassandraDatabase:
|
||||
the subset of keyspaces that have been resolved in this instance.
|
||||
|
||||
Returns:
|
||||
A markdown string that documents the schema of all resolved keyspaces and
|
||||
their tables within this CassandraDatabase instance. This includes keyspace
|
||||
names, table names, comments, columns, partition keys, clustering keys,
|
||||
and indexes for each table.
|
||||
A markdown string that documents the schema of all resolved keyspaces and
|
||||
their tables within this CassandraDatabase instance. This includes keyspace
|
||||
names, table names, comments, columns, partition keys, clustering keys,
|
||||
and indexes for each table.
|
||||
"""
|
||||
schema = self._resolve_schema()
|
||||
output = "# Cassandra Database Schema\n\n"
|
||||
@ -201,18 +164,18 @@ class CassandraDatabase:
|
||||
Ensures that `cql` starts with the specified type (e.g., SELECT) and does
|
||||
not contain content that could indicate CQL injection vulnerabilities.
|
||||
|
||||
Parameters:
|
||||
- cql (str): The CQL query string to be validated.
|
||||
- type (str): The expected starting keyword of the query, used to verify
|
||||
that the query begins with the correct operation type
|
||||
(e.g., "SELECT", "UPDATE"). Defaults to "SELECT".
|
||||
Args:
|
||||
cql: The CQL query string to be validated.
|
||||
type: The expected starting keyword of the query, used to verify
|
||||
that the query begins with the correct operation type
|
||||
(e.g., "SELECT", "UPDATE"). Defaults to "SELECT".
|
||||
|
||||
Returns:
|
||||
- str: The trimmed and validated CQL query string without a trailing semicolon.
|
||||
The trimmed and validated CQL query string without a trailing semicolon.
|
||||
|
||||
Raises:
|
||||
- ValueError: If the value of `type` is not supported
|
||||
- DatabaseError: If `cql` is considered unsafe
|
||||
ValueError: If the value of `type` is not supported
|
||||
DatabaseError: If `cql` is considered unsafe
|
||||
"""
|
||||
SUPPORTED_TYPES = ["SELECT"]
|
||||
if type and type.upper() not in SUPPORTED_TYPES:
|
||||
@ -246,29 +209,26 @@ class CassandraDatabase:
|
||||
# The trimmed query, before modifications
|
||||
return cql_trimmed
|
||||
|
||||
def _fetch_keyspaces(self, keyspace_list: Optional[List[str]] = None) -> List[str]:
|
||||
def _fetch_keyspaces(self, keyspaces: Optional[List[str]] = None) -> List[str]:
|
||||
"""
|
||||
Fetches a list of keyspace names from the Cassandra database. The list can be
|
||||
filtered by a provided list of keyspace names or by excluding predefined
|
||||
keyspaces.
|
||||
|
||||
Parameters:
|
||||
- keyspace_list (Optional[List[str]]): A list of keyspace names to specifically
|
||||
include. If provided and not empty, the method returns only the keyspaces
|
||||
present in this list. If not provided or empty, the method returns all
|
||||
keyspaces except those specified in the _exclude_keyspaces attribute.
|
||||
Args:
|
||||
keyspaces: A list of keyspace names to specifically include.
|
||||
If provided and not empty, the method returns only the keyspaces
|
||||
present in this list.
|
||||
If not provided or empty, the method returns all keyspaces except those
|
||||
specified in the _exclude_keyspaces attribute.
|
||||
|
||||
Returns:
|
||||
- List[str]: A list of keyspace names according to the filtering criteria.
|
||||
A list of keyspace names according to the filtering criteria.
|
||||
"""
|
||||
all_keyspaces = self.run(
|
||||
"SELECT keyspace_name FROM system_schema.keyspaces", fetch="all"
|
||||
all_keyspaces = self.fetch_all(
|
||||
"SELECT keyspace_name FROM system_schema.keyspaces"
|
||||
)
|
||||
|
||||
# Type check to ensure 'all_keyspaces' is a sequence of dictionaries
|
||||
if not isinstance(all_keyspaces, Sequence):
|
||||
raise TypeError("Expected a sequence of dictionaries from 'run' method.")
|
||||
|
||||
# Filtering keyspaces based on 'keyspaces' and '_exclude_keyspaces'
|
||||
filtered_keyspaces = []
|
||||
for ks in all_keyspaces:
|
||||
@ -276,87 +236,105 @@ class CassandraDatabase:
|
||||
continue # Skip if the row is not a dictionary.
|
||||
|
||||
keyspace_name = ks["keyspace_name"]
|
||||
if keyspace_list and keyspace_name in keyspace_list:
|
||||
if keyspaces and keyspace_name in keyspaces:
|
||||
filtered_keyspaces.append(keyspace_name)
|
||||
elif not keyspace_list and keyspace_name not in self._exclude_keyspaces:
|
||||
elif not keyspaces and keyspace_name not in self._exclude_keyspaces:
|
||||
filtered_keyspaces.append(keyspace_name)
|
||||
|
||||
return filtered_keyspaces
|
||||
|
||||
def _fetch_schema_data(self, keyspace_list: List[str]) -> Tuple:
|
||||
"""
|
||||
Fetches schema data, including tables, columns, and indexes, filtered by a
|
||||
list of keyspaces. This method constructs CQL queries to retrieve detailed
|
||||
schema information from the specified keyspaces and executes them to gather
|
||||
data about tables, columns, and indexes within those keyspaces.
|
||||
def _format_keyspace_query(self, query: str, keyspaces: List[str]) -> str:
    """Append a ``WHERE keyspace_name IN (...)`` filter to ``query``.

    Each keyspace name is rendered as a CQL string literal. Single quotes
    inside a name are doubled, so a name containing ``'`` cannot terminate
    the literal early and alter the statement.

    Args:
        query: The CQL statement to extend; it must not already contain a
            WHERE clause, since one is appended verbatim.
        keyspaces: Keyspace names to include in the IN clause.

    Returns:
        ``query`` followed by `` WHERE keyspace_name IN (...)`` listing the
        quoted keyspace names separated by ``", "``.
    """
    # CQL escapes a single quote inside a string literal by repeating it.
    quoted_names = ["'" + str(ks).replace("'", "''") + "'" for ks in keyspaces]
    keyspace_in_clause = ", ".join(quoted_names)
    return f"{query} WHERE keyspace_name IN ({keyspace_in_clause})"
|
||||
|
||||
Parameters:
|
||||
- keyspace_list (List[str]): A list of keyspace names from which to fetch
|
||||
schema data.
|
||||
def _fetch_tables_data(self, keyspaces: List[str]) -> list:
|
||||
"""Fetches tables schema data, filtered by a list of keyspaces.
|
||||
This method allows for efficiently fetching schema information for multiple
|
||||
keyspaces in a single operation, enabling applications to programmatically
|
||||
analyze or document the database schema.
|
||||
|
||||
Args:
|
||||
keyspaces: A list of keyspace names from which to fetch tables schema data.
|
||||
|
||||
Returns:
|
||||
- Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: A
|
||||
tuple containing three lists:
|
||||
- The first list contains dictionaries of table details (keyspace name,
|
||||
table name, and comment).
|
||||
- The second list contains dictionaries of column details (keyspace name,
|
||||
table name, column name, type, kind, and position).
|
||||
- The third list contains dictionaries of index details (keyspace name,
|
||||
table name, index name, kind, and options).
|
||||
|
||||
This method allows for efficiently fetching schema information for multiple
|
||||
keyspaces in a single operation,
|
||||
enabling applications to programmatically analyze or document the database
|
||||
schema.
|
||||
Dictionaries of table details (keyspace name, table name, and comment).
|
||||
"""
|
||||
# Construct IN clause for CQL query
|
||||
keyspace_in_clause = ", ".join([f"'{ks}'" for ks in keyspace_list])
|
||||
tables_query = self._format_keyspace_query(
|
||||
"SELECT keyspace_name, table_name, comment FROM system_schema.tables",
|
||||
keyspaces,
|
||||
)
|
||||
return self.fetch_all(tables_query)
|
||||
|
||||
# Fetch filtered table details
|
||||
tables_query = f"""SELECT keyspace_name, table_name, comment
|
||||
FROM system_schema.tables
|
||||
WHERE keyspace_name
|
||||
IN ({keyspace_in_clause})"""
|
||||
def _fetch_columns_data(self, keyspaces: List[str]) -> list:
|
||||
"""Fetches columns schema data, filtered by a list of keyspaces.
|
||||
This method allows for efficiently fetching schema information for multiple
|
||||
keyspaces in a single operation, enabling applications to programmatically
|
||||
analyze or document the database schema.
|
||||
|
||||
tables_data = self.run(tables_query, fetch="all")
|
||||
Args:
|
||||
keyspaces: A list of keyspace names from which to fetch columns schema data.
|
||||
|
||||
# Fetch filtered column details
|
||||
columns_query = f"""SELECT keyspace_name, table_name, column_name, type,
|
||||
kind, clustering_order, position
|
||||
FROM system_schema.columns
|
||||
WHERE keyspace_name
|
||||
IN ({keyspace_in_clause})"""
|
||||
Returns:
|
||||
Dictionaries of column details (keyspace name, table name, column name,
|
||||
type, kind, clustering order, and position).
|
||||
"""
|
||||
tables_query = self._format_keyspace_query(
|
||||
"""
|
||||
SELECT keyspace_name, table_name, column_name, type, kind,
|
||||
clustering_order, position
|
||||
FROM system_schema.columns
|
||||
""",
|
||||
keyspaces,
|
||||
)
|
||||
return self.fetch_all(tables_query)
|
||||
|
||||
columns_data = self.run(columns_query, fetch="all")
|
||||
def _fetch_indexes_data(self, keyspaces: List[str]) -> list:
|
||||
"""Fetches indexes schema data, filtered by a list of keyspaces.
|
||||
This method allows for efficiently fetching schema information for multiple
|
||||
keyspaces in a single operation, enabling applications to programmatically
|
||||
analyze or document the database schema.
|
||||
|
||||
# Fetch filtered index details
|
||||
indexes_query = f"""SELECT keyspace_name, table_name, index_name,
|
||||
kind, options
|
||||
FROM system_schema.indexes
|
||||
WHERE keyspace_name
|
||||
IN ({keyspace_in_clause})"""
|
||||
Args:
|
||||
keyspaces: A list of keyspace names from which to fetch indexes schema data.
|
||||
|
||||
indexes_data = self.run(indexes_query, fetch="all")
|
||||
|
||||
return tables_data, columns_data, indexes_data
|
||||
Returns:
|
||||
Dictionaries of index details (keyspace name, table name, index name, kind,
|
||||
and options).
|
||||
"""
|
||||
tables_query = self._format_keyspace_query(
|
||||
"""
|
||||
SELECT keyspace_name, table_name, index_name,
|
||||
kind, options
|
||||
FROM system_schema.indexes
|
||||
""",
|
||||
keyspaces,
|
||||
)
|
||||
return self.fetch_all(tables_query)
|
||||
|
||||
def _resolve_schema(
|
||||
self, keyspace_list: Optional[List[str]] = None
|
||||
self, keyspaces: Optional[List[str]] = None
|
||||
) -> Dict[str, List[Table]]:
|
||||
"""
|
||||
Efficiently fetches and organizes Cassandra table schema information,
|
||||
such as comments, columns, and indexes, into a dictionary mapping keyspace
|
||||
names to lists of Table objects.
|
||||
|
||||
Returns:
|
||||
A dictionary with keyspace names as keys and lists of Table objects as values,
|
||||
where each Table object is populated with schema details appropriate for its
|
||||
keyspace and table name.
|
||||
"""
|
||||
if not keyspace_list:
|
||||
keyspace_list = self._fetch_keyspaces()
|
||||
Args:
|
||||
keyspaces: An optional list of keyspace names from which to fetch tables
|
||||
schema data.
|
||||
|
||||
tables_data, columns_data, indexes_data = self._fetch_schema_data(keyspace_list)
|
||||
Returns:
|
||||
A dictionary with keyspace names as keys and lists of Table objects as
|
||||
values, where each Table object is populated with schema details
|
||||
appropriate for its keyspace and table name.
|
||||
"""
|
||||
if not keyspaces:
|
||||
keyspaces = self._fetch_keyspaces()
|
||||
|
||||
tables_data = self._fetch_tables_data(keyspaces)
|
||||
columns_data = self._fetch_columns_data(keyspaces)
|
||||
indexes_data = self._fetch_indexes_data(keyspaces)
|
||||
|
||||
keyspace_dict: dict = {}
|
||||
for table_data in tables_data:
|
||||
@ -415,11 +393,11 @@ class CassandraDatabase:
|
||||
|
||||
return keyspace_dict
|
||||
|
||||
@staticmethod
|
||||
def _resolve_session(
|
||||
self,
|
||||
session: Optional[Session] = None,
|
||||
cassio_init_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
) -> Optional[Session]:
|
||||
"""
|
||||
Attempts to resolve and return a Session object for use in database operations.
|
||||
|
||||
@ -430,18 +408,17 @@ class CassandraDatabase:
|
||||
3. A new `cassio` session derived from `cassio_init_kwargs`,
|
||||
4. `None`
|
||||
|
||||
Parameters:
|
||||
- session (Optional[Session]): An optional session to use directly.
|
||||
- cassio_init_kwargs (Optional[Dict[str, Any]]): An optional dictionary of
|
||||
keyword arguments to `cassio`.
|
||||
Args:
|
||||
session: An optional session to use directly.
|
||||
cassio_init_kwargs: An optional dictionary of keyword arguments to `cassio`.
|
||||
|
||||
Returns:
|
||||
- Session: The resolved session object if successful, or `None` if the session
|
||||
cannot be resolved.
|
||||
The resolved session object if successful, or `None` if the session
|
||||
cannot be resolved.
|
||||
|
||||
Raises:
|
||||
- ValueError: If `cassio_init_kwargs` is provided but is not a dictionary of
|
||||
keyword arguments.
|
||||
ValueError: If `cassio_init_kwargs` is provided but is not a dictionary of
|
||||
keyword arguments.
|
||||
"""
|
||||
|
||||
# Prefer given session
|
||||
@ -535,20 +512,18 @@ class Table(BaseModel):
|
||||
Generates a Markdown representation of the Cassandra table schema, allowing for
|
||||
customizable header levels for the table name section.
|
||||
|
||||
Parameters:
|
||||
- include_keyspace (bool): If True, includes the keyspace in the output.
|
||||
Defaults to True.
|
||||
- header_level (Optional[int]): Specifies the markdown header level for the
|
||||
table name.
|
||||
If None, the table name is included without a header. Defaults to None
|
||||
(no header level).
|
||||
Args:
|
||||
include_keyspace: If True, includes the keyspace in the output.
|
||||
Defaults to True.
|
||||
header_level: Specifies the markdown header level for the table name.
|
||||
If None, the table name is included without a header.
|
||||
Defaults to None (no header level).
|
||||
|
||||
Returns:
|
||||
- str: A string in Markdown format detailing the table name
|
||||
(with optional header level),
|
||||
keyspace (optional), comment, columns, partition keys, clustering keys
|
||||
(with optional clustering order),
|
||||
and indexes.
|
||||
A string in Markdown format detailing the table name
|
||||
(with optional header level), keyspace (optional), comment, columns,
|
||||
partition keys, clustering keys (with optional clustering order),
|
||||
and indexes.
|
||||
"""
|
||||
output = ""
|
||||
if header_level is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user