mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
e54a49b697
For some DBs with lots of tables, reflecting all of the tables can take a very long time. With this change, tables are reflected lazily, when get_table_info() is called, if `lazy_table_reflection` is True.
570 lines
22 KiB
Python
570 lines
22 KiB
Python
"""SQLAlchemy wrapper around a database."""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union
|
|
|
|
import sqlalchemy
|
|
from langchain_core._api import deprecated
|
|
from langchain_core.utils import get_from_env
|
|
from sqlalchemy import (
|
|
MetaData,
|
|
Table,
|
|
create_engine,
|
|
inspect,
|
|
select,
|
|
text,
|
|
)
|
|
from sqlalchemy.engine import Engine, Result
|
|
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
|
from sqlalchemy.schema import CreateTable
|
|
from sqlalchemy.sql.expression import Executable
|
|
from sqlalchemy.types import NullType
|
|
|
|
|
|
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
|
return (
|
|
f'Name: {index["name"]}, Unique: {index["unique"]},'
|
|
f' Columns: {str(index["column_names"])}'
|
|
)
|
|
|
|
|
|
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Shorten a string to at most ``length`` characters, cutting at a space.

    The limit is measured in characters, not words. When the string is too
    long it is cut so that the text plus ``suffix`` fits in ``length``, and the
    cut is moved back to the last space inside that window (if there is one)
    so a word is not split in half.

    Args:
        content: Value to shorten. Anything that is not a ``str`` is returned
            unchanged.
        length: Maximum number of characters in the result. A value ``<= 0``
            disables truncation.
        suffix: Marker appended to truncated text.

    Returns:
        ``content`` itself when no truncation applies, otherwise the shortened
        text.
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    room = length - len(suffix)
    if room < 0:
        # The suffix alone is longer than the limit. A negative slice bound
        # would count from the end of the string and could return more
        # characters than the input had, so hard-truncate instead.
        return content[:length]

    return content[:room].rsplit(" ", 1)[0] + suffix
|
|
|
|
|
|
class SQLDatabase:
    """SQLAlchemy wrapper around a database.

    Wraps an ``Engine`` and exposes helpers to list the usable tables, describe
    them (``CREATE TABLE`` text, optional indexes and sample rows) and run SQL
    statements, returning results as strings.
    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
        lazy_table_reflection: bool = False,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: Engine used for inspection, reflection and execution.
            schema: Schema passed to the inspector and to reflection; also used
                by ``_execute`` to select the schema on some dialects.
            metadata: ``MetaData`` to reflect tables into. A new one is created
                when omitted.
            ignore_tables: Table names to hide. Mutually exclusive with
                ``include_tables``; every name must exist in the database.
            include_tables: The only table names to expose. Every name must
                exist in the database.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description; ``0`` disables sample rows.
            indexes_in_table_info: Whether to append index information to each
                table description.
            custom_table_info: Mapping of table name to a ready-made
                description used instead of the generated one. Entries for
                tables not present in the database are dropped.
            view_support: Whether views are listed and reflected along with
                tables.
            max_string_length: Maximum length of string values in the output of
                ``run``.
            lazy_table_reflection: When true, nothing is reflected here; tables
                are reflected on demand by ``get_table_info``.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or either names a table that is not in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an ``int`` or
                ``custom_table_info`` is truthy but not a ``dict``.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # Views are added to the set of known tables only when view_support
        # is enabled.
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            known = set(self._custom_table_info).intersection(self._all_tables)
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in known
            }

        self._max_string_length = max_string_length
        self._view_support = view_support

        self._metadata = metadata or MetaData()
        if not lazy_table_reflection:
            # Eagerly reflect every usable table (and view, if enabled).
            self._metadata.reflect(
                views=view_support,
                bind=self._engine,
                only=list(self._usable_tables),
                schema=self._schema,
            )

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SQLDatabase:
        """Construct an instance from a database URI.

        Args:
            database_uri: URI handed to ``sqlalchemy.create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the class constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a Databricks connection.
        This method requires the 'databricks-sql-connector' package. If not installed,
        it can be added using `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The Databricks workspace hostname, excluding
                'https://' part. If not provided, it attempts to fetch from the
                environment variable 'DATABRICKS_HOST'. If still unavailable and if
                running in a Databricks notebook, it defaults to the current workspace
                hostname. Defaults to None.
            api_token (Optional[str]): The Databricks personal access token for
                accessing the Databricks SQL warehouse or the cluster. If not provided,
                it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
                and running in a Databricks notebook, a temporary token for the current
                user is generated. Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
                provided, the method configures the connection to use this warehouse.
                Cannot be used with 'cluster_id'. Defaults to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
                provided, the method configures the connection to use this cluster.
                Cannot be used with 'warehouse_id'. If running in a Databricks notebook
                and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
                cluster the notebook is attached to. Defaults to None.
            engine_args (Optional[dict]): The arguments to be used when connecting
                Databricks. Defaults to None.
            **kwargs (Any): Additional keyword arguments for the `from_uri` method.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
                Databricks connection details.

        Raises:
            ValueError: If 'databricks-sql-connector' is not found, or if both
                'warehouse_id' and 'cluster_id' are provided, or if neither
                'warehouse_id' nor 'cluster_id' are provided and it's not executing
                inside a Databricks notebook.
        """
        try:
            from databricks import sql  # noqa: F401
        except ImportError as e:
            raise ValueError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            ) from e

        # The notebook context is only importable inside a Databricks runtime;
        # elsewhere it stays None and explicit arguments / env vars are used.
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
        except ImportError:
            pass

        default_host = context.browserHostName if context else None
        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a CnosDB connection.
        This method requires the 'cnos-connector' package. If not installed, it
        can be added using `pip install cnos-connector`.

        Args:
            url (str): The HTTP connection host name and port number of the CnosDB
                service, excluding "http://" or "https://", with a default value
                of "127.0.0.1:8902".
            user (str): The username used to connect to the CnosDB service, with a
                default value of "root".
            password (str): The password of the user connecting to the CnosDB service,
                with a default value of "".
            tenant (str): The name of the tenant used to connect to the CnosDB service,
                with a default value of "cnosdb".
            database (str): The name of the database in the CnosDB tenant.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
            CnosDB connection details.

        Raises:
            ValueError: If the 'cnos-connector' package is not installed.
        """
        # Only the import is guarded, so that an ImportError raised while
        # creating the engine is not misreported as a missing cnos-connector.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError as e:
            raise ValueError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            ) from e

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns:
            The sorted ``include_tables`` when they were given, otherwise every
            known table minus ``ignore_tables``, sorted.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @deprecated("0.0.1", alternative="get_usable_table_names", removal="0.2.0")
    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available (deprecated alias)."""
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Tables that are not yet present in the metadata are reflected first.
        Columns whose reflected type is ``NullType`` are removed from the
        reflected ``Table`` objects as a side effect.

        Args:
            table_names: Tables to describe; all usable tables when ``None``.

        Returns:
            The table descriptions, sorted and separated by blank lines.

        Raises:
            ValueError: If ``table_names`` contains a name that is not usable.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Built once and reused for both the reflection diff and the filter.
        requested = set(all_table_names)

        already_reflected = {tbl.name for tbl in self._metadata.sorted_tables}
        to_reflect = requested - already_reflected
        if to_reflect:
            self._metadata.reflect(
                views=self._view_support,
                bind=self._engine,
                only=list(to_reflect),
                schema=self._schema,
            )

        is_sqlite = self.dialect == "sqlite"
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in requested
            and not (is_sqlite and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Drop columns whose type was reflected as NullType (i.e. a type
            # SQLAlchemy could not map, such as JSON on some dialects) before
            # the CREATE TABLE text is generated. Iterate over a snapshot
            # because the collection is modified inside the loop.
            for column in list(table.columns):
                if type(column.type) is NullType:
                    table._columns.remove(column)

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = create_table.rstrip()
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        final_str = "\n\n".join(tables)
        return final_str

    def _get_table_indexes(self, table: Table) -> str:
        """Return a ``Table Indexes:`` section listing the indexes of ``table``."""
        # Pass the table's schema so the lookup is not done in the default
        # schema when the table lives in another one.
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Return a tab-separated preview of the first rows of ``table``.

        At most ``sample_rows_in_table_info`` rows are selected and every value
        is converted to ``str`` and cut to 100 characters.
        """
        # build the select command
        command = select(table).limit(self._sample_rows_in_table_info)

        # save the columns in string format
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # get the sample rows
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # shorten values in the sample rows; consumed while the
                # connection is still open
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _execute(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[Sequence[Dict[str, Any]], Result]:
        """
        Executes SQL command through underlying engine.

        If the statement returns no rows, an empty list is returned.

        Args:
            command: SQL text or a SQLAlchemy executable.
            fetch: ``"all"`` for every row, ``"one"`` for at most the first
                row, ``"cursor"`` to get the raw result object.
            parameters: Bind parameters for the statement.
            execution_options: Options passed to each ``execute`` call.

        Returns:
            A list of row dictionaries, or the result object for ``"cursor"``.

        Raises:
            TypeError: If ``command`` is neither a ``str`` nor an ``Executable``.
            ValueError: If rows are returned and ``fetch`` is not a known mode.
        """
        # NOTE(review): with fetch="cursor" the result is returned from inside
        # the `engine.begin()` block, so the transaction/connection is released
        # on return -- confirm the result is still consumable by callers.
        parameters = parameters or {}
        execution_options = execution_options or {}
        with self._engine.begin() as connection:  # type: Connection  # type: ignore[name-defined]
            if self._schema is not None:
                if self.dialect == "snowflake":
                    connection.exec_driver_sql(
                        "ALTER SESSION SET search_path = %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "bigquery":
                    connection.exec_driver_sql(
                        "SET @@dataset_id=?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "mssql":
                    pass
                elif self.dialect == "trino":
                    connection.exec_driver_sql(
                        "USE ?",
                        (self._schema,),
                        execution_options=execution_options,
                    )
                elif self.dialect == "duckdb":
                    # Unclear which parameterized argument syntax duckdb supports.
                    # The docs for the duckdb client say they support multiple,
                    # but `duckdb_engine` seemed to struggle with all of them:
                    # https://github.com/Mause/duckdb_engine/issues/796
                    connection.exec_driver_sql(
                        f"SET search_path TO {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "oracle":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}",
                        execution_options=execution_options,
                    )
                elif self.dialect == "sqlany":
                    # If anybody using Sybase SQL anywhere database then it should not
                    # go to else condition. It should be same as mssql.
                    pass
                elif self.dialect == "postgresql":  # postgresql
                    connection.exec_driver_sql(
                        "SET search_path TO %s",
                        (self._schema,),
                        execution_options=execution_options,
                    )

            if isinstance(command, str):
                command = text(command)
            elif not isinstance(command, Executable):
                raise TypeError(f"Query expression has unknown type: {type(command)}")
            cursor = connection.execute(
                command,
                parameters,
                execution_options=execution_options,
            )

            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                elif fetch == "cursor":
                    return cursor
                else:
                    raise ValueError(
                        "Fetch parameter must be either 'one', 'all', or 'cursor'"
                    )
                return result
        return []

    def run(
        self,
        command: Union[str, Executable],
        fetch: Literal["all", "one", "cursor"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.
        With ``fetch="cursor"`` the value produced by ``_execute`` is returned
        as is instead of a string.

        Args:
            command: SQL text or a SQLAlchemy executable.
            fetch: ``"all"``, ``"one"`` or ``"cursor"``.
            include_columns: When true each row is rendered as a dict keyed by
                column name, otherwise as a tuple of values.
            parameters: Bind parameters for the statement.
            execution_options: Options passed to the execution.
        """
        result = self._execute(
            command, fetch, parameters=parameters, execution_options=execution_options
        )

        if fetch == "cursor":
            return result

        # String values are shortened to max_string_length before rendering.
        res = [
            {
                column: truncate_word(value, length=self._max_string_length)
                for column, value in r.items()
            }
            for r in result
        ]

        if not include_columns:
            res = [tuple(row.values()) for row in res]  # type: ignore[misc]

        if not res:
            return ""
        else:
            return str(res)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` raised by ``get_table_info`` is returned as an
        ``"Error: ..."`` string instead of propagating.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(
        self,
        command: str,
        fetch: Literal["all", "one"] = "all",
        include_columns: bool = False,
        *,
        parameters: Optional[Dict[str, Any]] = None,
        execution_options: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Sequence[Dict[str, Any]], Result[Any]]:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws an error, the error message is returned.
        Only ``SQLAlchemyError`` is converted this way; other exceptions
        propagate.
        """
        try:
            return self.run(
                command,
                fetch,
                parameters=parameters,
                execution_options=execution_options,
                include_columns=include_columns,
            )
        except SQLAlchemyError as e:
            # Format the error message
            return f"Error: {e}"

    def get_context(self) -> Dict[str, Any]:
        """Return db context that you may want in agent prompt.

        Returns:
            A dict with ``"table_info"`` (table descriptions, or an error
            string) and ``"table_names"`` (comma-separated usable table names).
        """
        table_names = list(self.get_usable_table_names())
        table_info = self.get_table_info_no_throw()
        return {"table_info": table_info, "table_names": ", ".join(table_names)}
|