From 683fb45c6bc8b4b31cdbdb870a8c8ed5bc561f13 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Thu, 2 May 2024 19:13:08 +0200 Subject: [PATCH] community[patch]: Refactor CassandraDatabase wrapper (#21075) * Introduce individual `fetch_` methods for easier typing. * Rework some docstrings to google style * Move some logic to the tool * Merge the 2 cassandra utility files --- .../tools/cassandra_database/tool.py | 20 +- .../utilities/cassandra_database.py | 295 ++++++++---------- 2 files changed, 152 insertions(+), 163 deletions(-) diff --git a/libs/community/langchain_community/tools/cassandra_database/tool.py b/libs/community/langchain_community/tools/cassandra_database/tool.py index 0cb36355a9..337fca5327 100644 --- a/libs/community/langchain_community/tools/cassandra_database/tool.py +++ b/libs/community/langchain_community/tools/cassandra_database/tool.py @@ -1,6 +1,7 @@ """Tools for interacting with an Apache Cassandra database.""" from __future__ import annotations +import traceback from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union from langchain_core.callbacks import CallbackManagerForToolRun @@ -43,7 +44,11 @@ class QueryCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool): run_manager: Optional[CallbackManagerForToolRun] = None, ) -> Union[str, Sequence[Dict[str, Any]], ResultSet]: """Execute the query, return the results or an error message.""" - return self.db.run_no_throw(query) + try: + return self.db.run(query) + except Exception as e: + """Format the error message""" + return f"Error: {e}\n{traceback.format_exc()}" class _GetSchemaCassandraDatabaseToolInput(BaseModel): @@ -73,7 +78,12 @@ class GetSchemaCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool): run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Get the schema for a keyspace.""" - return self.db.get_keyspace_tables_str_no_throw(keyspace) + try: + tables = self.db.get_keyspace_tables(keyspace) + return "".join([table.as_markdown() + "\n\n" for table in tables]) + except Exception as e: + """Format the error message""" + return f"Error: {e}\n{traceback.format_exc()}" class _GetTableDataCassandraDatabaseToolInput(BaseModel): @@ -123,4 +133,8 @@ class GetTableDataCassandraDatabaseTool(BaseCassandraDatabaseTool, BaseTool): run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Get data from a table in a keyspace.""" - return self.db.get_table_data_no_throw(keyspace, table, predicate, limit) + try: + return self.db.get_table_data(keyspace, table, predicate, limit) + except Exception as e: + """Format the error message""" + return f"Error: {e}\n{traceback.format_exc()}" diff --git a/libs/community/langchain_community/utilities/cassandra_database.py b/libs/community/langchain_community/utilities/cassandra_database.py index ccaae56018..6b607c4c43 100644 --- a/libs/community/langchain_community/utilities/cassandra_database.py +++ b/libs/community/langchain_community/utilities/cassandra_database.py @@ -2,7 +2,6 @@ from __future__ import annotations import re -import traceback from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union from langchain_core.pydantic_v1 import BaseModel, Field, root_validator @@ -32,9 +31,10 @@ class CassandraDatabase: include_tables: Optional[List[str]] = None, cassio_init_kwargs: Optional[Dict[str, Any]] = None, ): - self._session = self._resolve_session(session, cassio_init_kwargs) - if not self._session: + _session = self._resolve_session(session, cassio_init_kwargs) + if not _session: raise ValueError("Session not provided and cannot be resolved") + self._session = _session self._exclude_keyspaces = IGNORED_KEYSPACES self._exclude_tables = exclude_tables or [] @@ -44,52 +44,28 @@ class CassandraDatabase: self, query: str, fetch: str = "all", - include_columns: bool = False, **kwargs: Any, - ) -> Union[str, Sequence[Dict[str, Any]], ResultSet]: + ) -> Union[list, Dict[str, Any], ResultSet]: """Execute a CQL query and return the results.""" - clean_query = self._validate_cql(query, "SELECT") - result = self._session.execute(clean_query, **kwargs) if fetch == "all": - return list(result) + return self.fetch_all(query, **kwargs) elif fetch == "one": - return result.one()._asdict() if result else {} + return self.fetch_one(query, **kwargs) elif fetch == "cursor": - return result + return self._fetch(query, **kwargs) else: raise ValueError("Fetch parameter must be either 'one', 'all', or 'cursor'") - def run_no_throw( - self, - query: str, - fetch: str = "all", - include_columns: bool = False, - **kwargs: Any, - ) -> Union[str, Sequence[Dict[str, Any]], ResultSet]: - """Execute a CQL query and return the results or an error message.""" - try: - return self.run(query, fetch, include_columns, **kwargs) - except Exception as e: - """Format the error message""" - return f"Error: {e}\n{traceback.format_exc()}" - - def get_keyspace_tables_str_no_throw(self, keyspace: str) -> str: - """Get the tables for the specified keyspace.""" - try: - schema_string = self.get_keyspace_tables_str(keyspace) - return schema_string - except Exception as e: - """Format the error message""" - return f"Error: {e}\n{traceback.format_exc()}" + def _fetch(self, query: str, **kwargs: Any) -> ResultSet: + clean_query = self._validate_cql(query, "SELECT") + return self._session.execute(clean_query, **kwargs) - def get_keyspace_tables_str(self, keyspace: str) -> str: - """Get the tables for the specified keyspace.""" - tables = self.get_keyspace_tables(keyspace) - schema_string = "" - for table in tables: - schema_string += table.as_markdown() + "\n\n" + def fetch_all(self, query: str, **kwargs: Any) -> list: + return list(self._fetch(query, **kwargs)) - return schema_string + def fetch_one(self, query: str, **kwargs: Any) -> Dict[str, Any]: + result = self._fetch(query, **kwargs) + return result.one()._asdict() if result else {} def get_keyspace_tables(self, keyspace: str) -> List[Table]: """Get the Table objects for the specified keyspace.""" @@ -99,17 +75,6 @@ class CassandraDatabase: else: return [] - def get_table_data_no_throw( - self, keyspace: str, table: str, predicate: str, limit: int - ) -> str: - """Get data from the specified table in the specified keyspace. Optionally can - take a predicate for the WHERE clause and a limit.""" - try: - return self.get_table_data(keyspace, table, predicate, limit) - except Exception as e: - """Format the error message""" - return f"Error: {e}\n{traceback.format_exc()}" - # This is a more basic string building function that doesn't use a query builder # or prepared statements # TODO: Refactor to use prepared statements @@ -127,7 +92,7 @@ class CassandraDatabase: query += ";" - result = self.run(query, fetch="all") + result = self.fetch_all(query) data = "\n".join(str(row) for row in result) return data @@ -144,15 +109,13 @@ class CassandraDatabase: by iterating over all tables within that keyspace and calling their as_markdown method. - Parameters: - - keyspace (str): The name of the keyspace to generate markdown - documentation for. - - tables (list[Table]): list of tables in the keyspace; it will be resolved - if not provided. + Args: + keyspace: The name of the keyspace to generate markdown documentation for. + tables: list of tables in the keyspace; it will be resolved if not provided. Returns: - A string containing the markdown representation of the specified - keyspace schema. + A string containing the markdown representation of the specified + keyspace schema. """ if not tables: tables = self.get_keyspace_tables(keyspace) @@ -184,10 +147,10 @@ class CassandraDatabase: the subset of keyspaces that have been resolved in this instance. Returns: - A markdown string that documents the schema of all resolved keyspaces and - their tables within this CassandraDatabase instance. This includes keyspace - names, table names, comments, columns, partition keys, clustering keys, - and indexes for each table. + A markdown string that documents the schema of all resolved keyspaces and + their tables within this CassandraDatabase instance. This includes keyspace + names, table names, comments, columns, partition keys, clustering keys, + and indexes for each table. """ schema = self._resolve_schema() output = "# Cassandra Database Schema\n\n" @@ -201,18 +164,18 @@ class CassandraDatabase: Ensures that `cql` starts with the specified type (e.g., SELECT) and does not contain content that could indicate CQL injection vulnerabilities. - Parameters: - - cql (str): The CQL query string to be validated. - - type (str): The expected starting keyword of the query, used to verify - that the query begins with the correct operation type - (e.g., "SELECT", "UPDATE"). Defaults to "SELECT". + Args: + cql: The CQL query string to be validated. + type: The expected starting keyword of the query, used to verify + that the query begins with the correct operation type + (e.g., "SELECT", "UPDATE"). Defaults to "SELECT". Returns: - - str: The trimmed and validated CQL query string without a trailing semicolon. + The trimmed and validated CQL query string without a trailing semicolon. Raises: - - ValueError: If the value of `type` is not supported - - DatabaseError: If `cql` is considered unsafe + ValueError: If the value of `type` is not supported + DatabaseError: If `cql` is considered unsafe """ SUPPORTED_TYPES = ["SELECT"] if type and type.upper() not in SUPPORTED_TYPES: @@ -246,29 +209,26 @@ class CassandraDatabase: # The trimmed query, before modifications return cql_trimmed - def _fetch_keyspaces(self, keyspace_list: Optional[List[str]] = None) -> List[str]: + def _fetch_keyspaces(self, keyspaces: Optional[List[str]] = None) -> List[str]: """ Fetches a list of keyspace names from the Cassandra database. The list can be filtered by a provided list of keyspace names or by excluding predefined keyspaces. - Parameters: - - keyspace_list (Optional[List[str]]): A list of keyspace names to specifically - include. If provided and not empty, the method returns only the keyspaces - present in this list. If not provided or empty, the method returns all - keyspaces except those specified in the _exclude_keyspaces attribute. + Args: + keyspaces: A list of keyspace names to specifically include. + If provided and not empty, the method returns only the keyspaces + present in this list. + If not provided or empty, the method returns all keyspaces except those + specified in the _exclude_keyspaces attribute. Returns: - - List[str]: A list of keyspace names according to the filtering criteria. + A list of keyspace names according to the filtering criteria. """ - all_keyspaces = self.run( - "SELECT keyspace_name FROM system_schema.keyspaces", fetch="all" + all_keyspaces = self.fetch_all( + "SELECT keyspace_name FROM system_schema.keyspaces" ) - # Type check to ensure 'all_keyspaces' is a sequence of dictionaries - if not isinstance(all_keyspaces, Sequence): - raise TypeError("Expected a sequence of dictionaries from 'run' method.") - # Filtering keyspaces based on 'keyspace_list' and '_exclude_keyspaces' filtered_keyspaces = [] for ks in all_keyspaces: @@ -276,87 +236,105 @@ class CassandraDatabase: continue # Skip if the row is not a dictionary. keyspace_name = ks["keyspace_name"] - if keyspace_list and keyspace_name in keyspace_list: + if keyspaces and keyspace_name in keyspaces: filtered_keyspaces.append(keyspace_name) - elif not keyspace_list and keyspace_name not in self._exclude_keyspaces: + elif not keyspaces and keyspace_name not in self._exclude_keyspaces: filtered_keyspaces.append(keyspace_name) return filtered_keyspaces - def _fetch_schema_data(self, keyspace_list: List[str]) -> Tuple: - """ - Fetches schema data, including tables, columns, and indexes, filtered by a - list of keyspaces. This method constructs CQL queries to retrieve detailed - schema information from the specified keyspaces and executes them to gather - data about tables, columns, and indexes within those keyspaces. + def _format_keyspace_query(self, query: str, keyspaces: List[str]) -> str: + # Construct IN clause for CQL query + keyspace_in_clause = ", ".join([f"'{ks}'" for ks in keyspaces]) + return f"""{query} WHERE keyspace_name IN ({keyspace_in_clause})""" - Parameters: - - keyspace_list (List[str]): A list of keyspace names from which to fetch - schema data. + def _fetch_tables_data(self, keyspaces: List[str]) -> list: + """Fetches tables schema data, filtered by a list of keyspaces. + This method allows for efficiently fetching schema information for multiple + keyspaces in a single operation, enabling applications to programmatically + analyze or document the database schema. - Returns: - - Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: A - tuple containing three lists: - - The first list contains dictionaries of table details (keyspace name, - table name, and comment). - - The second list contains dictionaries of column details (keyspace name, - table name, column name, type, kind, and position). - - The third list contains dictionaries of index details (keyspace name, - table name, index name, kind, and options). + Args: + keyspaces: A list of keyspace names from which to fetch tables schema data. - This method allows for efficiently fetching schema information for multiple - keyspaces in a single operation, - enabling applications to programmatically analyze or document the database - schema. + Returns: + Dictionaries of table details (keyspace name, table name, and comment). """ - # Construct IN clause for CQL query - keyspace_in_clause = ", ".join([f"'{ks}'" for ks in keyspace_list]) - - # Fetch filtered table details - tables_query = f"""SELECT keyspace_name, table_name, comment - FROM system_schema.tables - WHERE keyspace_name - IN ({keyspace_in_clause})""" + tables_query = self._format_keyspace_query( + "SELECT keyspace_name, table_name, comment FROM system_schema.tables", + keyspaces, + ) + return self.fetch_all(tables_query) - tables_data = self.run(tables_query, fetch="all") + def _fetch_columns_data(self, keyspaces: List[str]) -> list: + """Fetches columns schema data, filtered by a list of keyspaces. + This method allows for efficiently fetching schema information for multiple + keyspaces in a single operation, enabling applications to programmatically + analyze or document the database schema. - # Fetch filtered column details - columns_query = f"""SELECT keyspace_name, table_name, column_name, type, - kind, clustering_order, position - FROM system_schema.columns - WHERE keyspace_name - IN ({keyspace_in_clause})""" + Args: + keyspaces: A list of keyspace names from which to fetch tables schema data. - columns_data = self.run(columns_query, fetch="all") + Returns: + Dictionaries of column details (keyspace name, table name, column name, + type, kind, and position). + """ + tables_query = self._format_keyspace_query( + """ + SELECT keyspace_name, table_name, column_name, type, kind, + clustering_order, position + FROM system_schema.columns + """, + keyspaces, + ) + return self.fetch_all(tables_query) - # Fetch filtered index details - indexes_query = f"""SELECT keyspace_name, table_name, index_name, - kind, options - FROM system_schema.indexes - WHERE keyspace_name - IN ({keyspace_in_clause})""" + def _fetch_indexes_data(self, keyspaces: List[str]) -> list: + """Fetches indexes schema data, filtered by a list of keyspaces. + This method allows for efficiently fetching schema information for multiple + keyspaces in a single operation, enabling applications to programmatically + analyze or document the database schema. - indexes_data = self.run(indexes_query, fetch="all") + Args: + keyspaces: A list of keyspace names from which to fetch tables schema data. - return tables_data, columns_data, indexes_data + Returns: + Dictionaries of index details (keyspace name, table name, index name, kind, + and options). + """ + tables_query = self._format_keyspace_query( + """ + SELECT keyspace_name, table_name, index_name, + kind, options + FROM system_schema.indexes + """, + keyspaces, + ) + return self.fetch_all(tables_query) def _resolve_schema( - self, keyspace_list: Optional[List[str]] = None + self, keyspaces: Optional[List[str]] = None ) -> Dict[str, List[Table]]: """ Efficiently fetches and organizes Cassandra table schema information, such as comments, columns, and indexes, into a dictionary mapping keyspace names to lists of Table objects. + Args: + keyspaces: An optional list of keyspace names from which to fetch tables + schema data. + Returns: - A dictionary with keyspace names as keys and lists of Table objects as values, - where each Table object is populated with schema details appropriate for its - keyspace and table name. + A dictionary with keyspace names as keys and lists of Table objects as + values, where each Table object is populated with schema details + appropriate for its keyspace and table name. """ - if not keyspace_list: - keyspace_list = self._fetch_keyspaces() + if not keyspaces: + keyspaces = self._fetch_keyspaces() - tables_data, columns_data, indexes_data = self._fetch_schema_data(keyspace_list) + tables_data = self._fetch_tables_data(keyspaces) + columns_data = self._fetch_columns_data(keyspaces) + indexes_data = self._fetch_indexes_data(keyspaces) keyspace_dict: dict = {} for table_data in tables_data: @@ -415,11 +393,11 @@ class CassandraDatabase: return keyspace_dict + @staticmethod def _resolve_session( - self, session: Optional[Session] = None, cassio_init_kwargs: Optional[Dict[str, Any]] = None, - ) -> Session: + ) -> Optional[Session]: """ Attempts to resolve and return a Session object for use in database operations. @@ -430,18 +408,17 @@ class CassandraDatabase: 3. A new `cassio` session derived from `cassio_init_kwargs`, 4. `None` - Parameters: - - session (Optional[Session]): An optional session to use directly. - - cassio_init_kwargs (Optional[Dict[str, Any]]): An optional dictionary of - keyword arguments to `cassio`. + Args: + session: An optional session to use directly. + cassio_init_kwargs: An optional dictionary of keyword arguments to `cassio`. Returns: - - Session: The resolved session object if successful, or `None` if the session - cannot be resolved. + The resolved session object if successful, or `None` if the session + cannot be resolved. Raises: - - ValueError: If `cassio_init_kwargs` is provided but is not a dictionary of - keyword arguments. + ValueError: If `cassio_init_kwargs` is provided but is not a dictionary of + keyword arguments. """ # Prefer given session @@ -535,20 +512,18 @@ class Table(BaseModel): Generates a Markdown representation of the Cassandra table schema, allowing for customizable header levels for the table name section. - Parameters: - - include_keyspace (bool): If True, includes the keyspace in the output. - Defaults to True. - - header_level (Optional[int]): Specifies the markdown header level for the - table name. - If None, the table name is included without a header. Defaults to None - (no header level). + Args: + include_keyspace: If True, includes the keyspace in the output. + Defaults to True. + header_level: Specifies the markdown header level for the table name. + If None, the table name is included without a header. + Defaults to None (no header level). Returns: - - str: A string in Markdown format detailing the table name - (with optional header level), - keyspace (optional), comment, columns, partition keys, clustering keys - (with optional clustering order), - and indexes. + A string in Markdown format detailing the table name + (with optional header level), keyspace (optional), comment, columns, + partition keys, clustering keys (with optional clustering order), + and indexes. """ output = "" if header_level is not None: