"""SQLAlchemy wrapper around a database.""" from __future__ import annotations import warnings from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence import sqlalchemy from langchain_core.utils import get_from_env from sqlalchemy import MetaData, Table, create_engine, inspect, select, text from sqlalchemy.engine import Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError from sqlalchemy.schema import CreateTable from sqlalchemy.types import NullType def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: return ( f'Name: {index["name"]}, Unique: {index["unique"]},' f' Columns: {str(index["column_names"])}' ) def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str: """ Truncate a string to a certain number of words, based on the max string length. """ if not isinstance(content, str) or length <= 0: return content if len(content) <= length: return content return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix class SQLDatabase: """SQLAlchemy wrapper around a database.""" def __init__( self, engine: Engine, schema: Optional[str] = None, metadata: Optional[MetaData] = None, ignore_tables: Optional[List[str]] = None, include_tables: Optional[List[str]] = None, sample_rows_in_table_info: int = 3, indexes_in_table_info: bool = False, custom_table_info: Optional[dict] = None, view_support: bool = False, max_string_length: int = 300, ): """Create engine from database URI.""" self._engine = engine self._schema = schema if include_tables and ignore_tables: raise ValueError("Cannot specify both include_tables and ignore_tables") self._inspector = inspect(self._engine) # including view support by adding the views as well as tables to the all # tables list if view_support is True self._all_tables = set( self._inspector.get_table_names(schema=schema) + (self._inspector.get_view_names(schema=schema) if view_support else []) ) self._include_tables = set(include_tables) if include_tables else set() if self._include_tables: missing_tables = self._include_tables - self._all_tables if missing_tables: raise ValueError( f"include_tables {missing_tables} not found in database" ) self._ignore_tables = set(ignore_tables) if ignore_tables else set() if self._ignore_tables: missing_tables = self._ignore_tables - self._all_tables if missing_tables: raise ValueError( f"ignore_tables {missing_tables} not found in database" ) usable_tables = self.get_usable_table_names() self._usable_tables = set(usable_tables) if usable_tables else self._all_tables if not isinstance(sample_rows_in_table_info, int): raise TypeError("sample_rows_in_table_info must be an integer") self._sample_rows_in_table_info = sample_rows_in_table_info self._indexes_in_table_info = indexes_in_table_info self._custom_table_info = custom_table_info if self._custom_table_info: if not isinstance(self._custom_table_info, dict): raise TypeError( "table_info must be a dictionary with table names as keys and the " "desired table info as values" ) # only keep the tables that are also present in the database intersection = set(self._custom_table_info).intersection(self._all_tables) self._custom_table_info = dict( (table, self._custom_table_info[table]) for table in self._custom_table_info if table in intersection ) self._max_string_length = max_string_length self._metadata = metadata or MetaData() # including view support if view_support = true self._metadata.reflect( views=view_support, bind=self._engine, only=list(self._usable_tables), schema=self._schema, ) @classmethod def from_uri( cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any ) -> SQLDatabase: """Construct a SQLAlchemy engine from URI.""" _engine_args = engine_args or {} return cls(create_engine(database_uri, **_engine_args), **kwargs) @classmethod def from_databricks( cls, catalog: str, schema: str, host: Optional[str] = None, api_token: Optional[str] = None, warehouse_id: Optional[str] = None, cluster_id: Optional[str] = None, engine_args: Optional[dict] = None, **kwargs: Any, ) -> SQLDatabase: """ Class method to create an SQLDatabase instance from a Databricks connection. This method requires the 'databricks-sql-connector' package. If not installed, it can be added using `pip install databricks-sql-connector`. Args: catalog (str): The catalog name in the Databricks database. schema (str): The schema name in the catalog. host (Optional[str]): The Databricks workspace hostname, excluding 'https://' part. If not provided, it attempts to fetch from the environment variable 'DATABRICKS_HOST'. If still unavailable and if running in a Databricks notebook, it defaults to the current workspace hostname. Defaults to None. api_token (Optional[str]): The Databricks personal access token for accessing the Databricks SQL warehouse or the cluster. If not provided, it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable and running in a Databricks notebook, a temporary token for the current user is generated. Defaults to None. warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If provided, the method configures the connection to use this warehouse. Cannot be used with 'cluster_id'. Defaults to None. cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If provided, the method configures the connection to use this cluster. Cannot be used with 'warehouse_id'. If running in a Databricks notebook and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the cluster the notebook is attached to. Defaults to None. engine_args (Optional[dict]): The arguments to be used when connecting Databricks. Defaults to None. **kwargs (Any): Additional keyword arguments for the `from_uri` method. Returns: SQLDatabase: An instance of SQLDatabase configured with the provided Databricks connection details. Raises: ValueError: If 'databricks-sql-connector' is not found, or if both 'warehouse_id' and 'cluster_id' are provided, or if neither 'warehouse_id' nor 'cluster_id' are provided and it's not executing inside a Databricks notebook. """ try: from databricks import sql # noqa: F401 except ImportError: raise ValueError( "databricks-sql-connector package not found, please install with" " `pip install databricks-sql-connector`" ) context = None try: from dbruntime.databricks_repl_context import get_context context = get_context() except ImportError: pass default_host = context.browserHostName if context else None if host is None: host = get_from_env("host", "DATABRICKS_HOST", default_host) default_api_token = context.apiToken if context else None if api_token is None: api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token) if warehouse_id is None and cluster_id is None: if context: cluster_id = context.clusterId else: raise ValueError( "Need to provide either 'warehouse_id' or 'cluster_id'." ) if warehouse_id and cluster_id: raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.") if warehouse_id: http_path = f"/sql/1.0/warehouses/{warehouse_id}" else: http_path = f"/sql/protocolv1/o/0/{cluster_id}" uri = ( f"databricks://token:{api_token}@{host}?" f"http_path={http_path}&catalog={catalog}&schema={schema}" ) return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs) @classmethod def from_cnosdb( cls, url: str = "127.0.0.1:8902", user: str = "root", password: str = "", tenant: str = "cnosdb", database: str = "public", ) -> SQLDatabase: """ Class method to create an SQLDatabase instance from a CnosDB connection. This method requires the 'cnos-connector' package. If not installed, it can be added using `pip install cnos-connector`. Args: url (str): The HTTP connection host name and port number of the CnosDB service, excluding "http://" or "https://", with a default value of "127.0.0.1:8902". user (str): The username used to connect to the CnosDB service, with a default value of "root". password (str): The password of the user connecting to the CnosDB service, with a default value of "". tenant (str): The name of the tenant used to connect to the CnosDB service, with a default value of "cnosdb". database (str): The name of the database in the CnosDB tenant. Returns: SQLDatabase: An instance of SQLDatabase configured with the provided CnosDB connection details. """ try: from cnosdb_connector import make_cnosdb_langchain_uri uri = make_cnosdb_langchain_uri(url, user, password, tenant, database) return cls.from_uri(database_uri=uri) except ImportError: raise ValueError( "cnos-connector package not found, please install with" " `pip install cnos-connector`" ) @property def dialect(self) -> str: """Return string representation of dialect to use.""" return self._engine.dialect.name def get_usable_table_names(self) -> Iterable[str]: """Get names of tables available.""" if self._include_tables: return sorted(self._include_tables) return sorted(self._all_tables - self._ignore_tables) def get_table_names(self) -> Iterable[str]: """Get names of tables available.""" warnings.warn( "This method is deprecated - please use `get_usable_table_names`." ) return self.get_usable_table_names() @property def table_info(self) -> str: """Information about all tables in the database.""" return self.get_table_info() def get_table_info(self, table_names: Optional[List[str]] = None) -> str: """Get information about specified tables. Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498) If `sample_rows_in_table_info`, the specified number of sample rows will be appended to each table description. This can increase performance as demonstrated in the paper. """ all_table_names = self.get_usable_table_names() if table_names is not None: missing_tables = set(table_names).difference(all_table_names) if missing_tables: raise ValueError(f"table_names {missing_tables} not found in database") all_table_names = table_names meta_tables = [ tbl for tbl in self._metadata.sorted_tables if tbl.name in set(all_table_names) and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_")) ] tables = [] for table in meta_tables: if self._custom_table_info and table.name in self._custom_table_info: tables.append(self._custom_table_info[table.name]) continue # Ignore JSON datatyped columns for k, v in table.columns.items(): if type(v.type) is NullType: table._columns.remove(v) # add create table command create_table = str(CreateTable(table).compile(self._engine)) table_info = f"{create_table.rstrip()}" has_extra_info = ( self._indexes_in_table_info or self._sample_rows_in_table_info ) if has_extra_info: table_info += "\n\n/*" if self._indexes_in_table_info: table_info += f"\n{self._get_table_indexes(table)}\n" if self._sample_rows_in_table_info: table_info += f"\n{self._get_sample_rows(table)}\n" if has_extra_info: table_info += "*/" tables.append(table_info) tables.sort() final_str = "\n\n".join(tables) return final_str def _get_table_indexes(self, table: Table) -> str: indexes = self._inspector.get_indexes(table.name) indexes_formatted = "\n".join(map(_format_index, indexes)) return f"Table Indexes:\n{indexes_formatted}" def _get_sample_rows(self, table: Table) -> str: # build the select command command = select(table).limit(self._sample_rows_in_table_info) # save the columns in string format columns_str = "\t".join([col.name for col in table.columns]) try: # get the sample rows with self._engine.connect() as connection: sample_rows_result = connection.execute(command) # type: ignore # shorten values in the sample rows sample_rows = list( map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result) ) # save the sample rows in string format sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows]) # in some dialects when there are no rows in the table a # 'ProgrammingError' is returned except ProgrammingError: sample_rows_str = "" return ( f"{self._sample_rows_in_table_info} rows from {table.name} table:\n" f"{columns_str}\n" f"{sample_rows_str}" ) def _execute( self, command: str, fetch: Literal["all", "one"] = "all", ) -> Sequence[Dict[str, Any]]: """ Executes SQL command through underlying engine. If the statement returns no rows, an empty list is returned. """ with self._engine.begin() as connection: # type: Connection if self._schema is not None: if self.dialect == "snowflake": connection.exec_driver_sql( "ALTER SESSION SET search_path = %s", (self._schema,) ) elif self.dialect == "bigquery": connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,)) elif self.dialect == "mssql": pass elif self.dialect == "trino": connection.exec_driver_sql("USE ?", (self._schema,)) elif self.dialect == "duckdb": # Unclear which parameterized argument syntax duckdb supports. # The docs for the duckdb client say they support multiple, # but `duckdb_engine` seemed to struggle with all of them: # https://github.com/Mause/duckdb_engine/issues/796 connection.exec_driver_sql(f"SET search_path TO {self._schema}") elif self.dialect == "oracle": connection.exec_driver_sql( f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}" ) elif self.dialect == "sqlany": # If anybody using Sybase SQL anywhere database then it should not # go to else condition. It should be same as mssql. pass else: # postgresql and other compatible dialects connection.exec_driver_sql("SET search_path TO %s", (self._schema,)) cursor = connection.execute(text(command)) if cursor.returns_rows: if fetch == "all": result = [x._asdict() for x in cursor.fetchall()] elif fetch == "one": first_result = cursor.fetchone() result = [] if first_result is None else [first_result._asdict()] else: raise ValueError("Fetch parameter must be either 'one' or 'all'") return result return [] def run( self, command: str, fetch: Literal["all", "one"] = "all", include_columns: bool = False, ) -> str: """Execute a SQL command and return a string representing the results. If the statement returns rows, a string of the results is returned. If the statement returns no rows, an empty string is returned. """ result = self._execute(command, fetch) res = [ { column: truncate_word(value, length=self._max_string_length) for column, value in r.items() } for r in result ] if not include_columns: res = [tuple(row.values()) for row in res] if not res: return "" else: return str(res) def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str: """Get information about specified tables. Follows best practices as specified in: Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498) If `sample_rows_in_table_info`, the specified number of sample rows will be appended to each table description. This can increase performance as demonstrated in the paper. """ try: return self.get_table_info(table_names) except ValueError as e: """Format the error message""" return f"Error: {e}" def run_no_throw( self, command: str, fetch: Literal["all", "one"] = "all", include_columns: bool = False, ) -> str: """Execute a SQL command and return a string representing the results. If the statement returns rows, a string of the results is returned. If the statement returns no rows, an empty string is returned. If the statement throws an error, the error message is returned. """ try: return self.run(command, fetch, include_columns) except SQLAlchemyError as e: """Format the error message""" return f"Error: {e}"