from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, List, Optional

if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row, SparkSession


class SparkSQL:
    """SparkSQL is a utility class for interacting with Spark SQL."""

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Initialize a SparkSQL object.

        Args:
            spark_session: A SparkSession object.
                If not provided, one will be created.
            catalog: The catalog to use.
                If not provided, the default catalog will be used.
            schema: The schema to use.
                If not provided, the default schema will be used.
            ignore_tables: A list of tables to ignore.
                If not provided, all tables will be used.
            include_tables: A list of tables to include.
                If not provided, all tables will be used.
            sample_rows_in_table_info: The number of rows to include in the
                table info. Defaults to 3.

        Raises:
            ImportError: If pyspark is not installed.
            ValueError: If a name in ``include_tables`` or ``ignore_tables``
                is not reported by ``SHOW TABLES``.
            TypeError: If ``sample_rows_in_table_info`` is not an integer.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ImportError(
                "pyspark is not installed. "
                "Please install it with `pip install pyspark`"
            ) from e

        self._spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)

        self._all_tables = set(self._get_all_table_names())

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        usable_tables = self.get_usable_table_names()
        # Fall back to every table when filtering leaves nothing usable.
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Create a SparkSQL backed by a remote Spark session via Spark Connect.

        For example: SparkSQL.from_uri("sc://localhost:15002")

        Args:
            database_uri: The Spark Connect URI passed to
                ``SparkSession.builder.remote``.
            engine_args: Accepted but currently unused by this method.
            **kwargs: Forwarded to the ``SparkSQL`` constructor.

        Raises:
            ValueError: If pyspark is not installed.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            # ValueError (not ImportError) is kept for backward compatibility
            # with callers that already catch it.
            raise ValueError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns the included tables when ``include_tables`` was given,
        otherwise all tables minus the ignored ones. The result is always a
        new sorted list, so the order is deterministic and internal sets are
        not exposed to callers.
        """
        # sorting the result can help LLM understanding it.
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        """Return the ``tableName`` column of ``SHOW TABLES`` as a list."""
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return [row.tableName for row in rows]

    def _get_create_table_stmt(self, table: str) -> str:
        """Return the ``SHOW CREATE TABLE`` statement for ``table``.

        Everything from the first ``USING`` onwards is dropped and a ``;`` is
        appended. If the statement contains no ``USING``, it is returned in
        full with the ``;`` appended.
        """
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        if using_clause_index == -1:
            # str.find returns -1 when absent; slicing with it would silently
            # chop the last character off the statement.
            return statement + ";"
        return statement[:using_clause_index] + ";"

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get the CREATE statement (and optional sample rows) for tables.

        Args:
            table_names: Tables to describe. If None, all usable tables are
                described.

        Returns:
            One entry per table, separated by blank lines. When
            ``sample_rows_in_table_info`` is non-zero, each entry is followed
            by a ``/* ... */`` block holding the sample rows.

        Raises:
            ValueError: If any requested table is not among the usable tables.
        """
        selected_tables = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(selected_tables)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            selected_tables = table_names

        tables = []
        for table_name in selected_tables:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_spark_rows(self, table: str) -> str:
        """Return a tab-separated header and sample rows for ``table``."""
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(field.name for field in df.schema.fields)
        try:
            sample_rows = self._get_dataframe_results(df)
            # save the sample rows in string format
            sample_rows_str = "\n".join("\t".join(row) for row in sample_rows)
        except Exception:
            # Best effort: if the rows cannot be collected, the header is
            # still returned, with no row lines.
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        """Convert a row into a tuple of its values, each passed through ``str``."""
        return tuple(str(value) for value in row.asDict().values())

    def _get_dataframe_results(self, df: DataFrame) -> list:
        """Collect ``df`` and return its rows as a list of string tuples."""
        return [self._convert_row_as_tuple(row) for row in df.collect()]

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return the ``str()`` of the result rows.

        Args:
            command: The SQL text passed to ``SparkSession.sql``.
            fetch: ``"one"`` limits the result to a single row; any other
                value returns every row.

        Returns:
            The string form of a list of string tuples, e.g. ``"[('1', 'a')]"``.
        """
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` from ``get_table_info`` is returned as an
        ``"Error: ..."`` string instead of being raised.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, the string ``"[]"`` is returned.
        If the statement throws an error, the error message is returned
        prefixed with ``"Error: "``.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message
            return f"Error: {e}"