diff --git a/docs/docs/integrations/tools/databricks.ipynb b/docs/docs/integrations/tools/databricks.ipynb
new file mode 100644
index 0000000000..04bfc1f216
--- /dev/null
+++ b/docs/docs/integrations/tools/databricks.ipynb
@@ -0,0 +1,168 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Databricks Unity Catalog (UC)\n",
+    "\n",
+    "This notebook shows how to use UC functions as LangChain tools.\n",
+    "\n",
+    "See Databricks documentation ([AWS](https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)|[Azure](https://learn.microsoft.com/en-us/azure/databricks/sql/language-manual/sql-ref-syntax-ddl-create-sql-function)|[GCP](https://docs.gcp.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-sql-function.html)) to learn how to create SQL or Python functions in UC. Do not skip function and parameter comments, which are critical for LLMs to call functions properly.\n",
+    "\n",
+    "In this example notebook, we create a simple Python function that executes arbitrary code and use it as a LangChain tool:\n",
+    "\n",
+    "```sql\n",
+    "CREATE FUNCTION main.tools.python_exec (\n",
+    "  code STRING COMMENT 'Python code to execute. Remember to print the final result to stdout.'\n",
+    ")\n",
+    "RETURNS STRING\n",
+    "LANGUAGE PYTHON\n",
+    "COMMENT 'Executes Python code and returns its stdout.'\n",
+    "AS $$\n",
+    "  import sys\n",
+    "  from io import StringIO\n",
+    "  stdout = StringIO()\n",
+    "  sys.stdout = stdout\n",
+    "  exec(code)\n",
+    "  return stdout.getvalue()\n",
+    "$$\n",
+    "```\n",
+    "\n",
+    "It runs in a secure and isolated environment within a Databricks SQL warehouse."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install --upgrade --quiet databricks-sdk langchain-community langchain-openai"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_openai import ChatOpenAI\n",
+    "\n",
+    "llm = ChatOpenAI(model=\"gpt-3.5-turbo\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.tools.databricks import UCFunctionToolkit\n",
+    "\n",
+    "tools = (\n",
+    "    UCFunctionToolkit(\n",
+    "        # You can find the SQL warehouse ID in its UI after creation.\n",
+    "        warehouse_id=\"xxxx123456789\"\n",
+    "    )\n",
+    "    .include(\n",
+    "        # Include functions as tools using their qualified names.\n",
+    "        # You can use \"{catalog_name}.{schema_name}.*\" to get all functions in a schema.\n",
+    "        \"main.tools.python_exec\",\n",
+    "    )\n",
+    "    .get_tools()\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.agents import AgentExecutor, create_tool_calling_agent\n",
+    "from langchain_core.prompts import ChatPromptTemplate\n",
+    "\n",
+    "prompt = ChatPromptTemplate.from_messages(\n",
+    "    [\n",
+    "        (\n",
+    "            \"system\",\n",
+    "            \"You are a helpful assistant. Make sure to use tools for information.\",\n",
+    "        ),\n",
+    "        (\"placeholder\", \"{chat_history}\"),\n",
+    "        (\"human\", \"{input}\"),\n",
+    "        (\"placeholder\", \"{agent_scratchpad}\"),\n",
+    "    ]\n",
+    ")\n",
+    "\n",
+    "agent = create_tool_calling_agent(llm, tools, prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "\n",
+      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
+      "\u001b[32;1m\u001b[1;3m\n",
+      "Invoking: `main__tools__python_exec` with `{'code': 'print(36939 * 8922.4)'}`\n",
+      "\n",
+      "\n",
+      "\u001b[0m\u001b[36;1m\u001b[1;3m{\"format\": \"SCALAR\", \"value\": \"329584533.59999996\\n\", \"truncated\": false}\u001b[0m\u001b[32;1m\u001b[1;3mThe result of the multiplication 36939 * 8922.4 is 329,584,533.60.\u001b[0m\n",
+      "\n",
+      "\u001b[1m> Finished chain.\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'input': '36939 * 8922.4',\n",
+       " 'output': 'The result of the multiplication 36939 * 8922.4 is 329,584,533.60.'}"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)\n",
+    "agent_executor.invoke({\"input\": \"36939 * 8922.4\"})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "llm",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/community/langchain_community/tools/databricks/__init__.py b/libs/community/langchain_community/tools/databricks/__init__.py
new file mode 100644
index 0000000000..9a1d5ffe53
--- /dev/null
+++ b/libs/community/langchain_community/tools/databricks/__init__.py
@@ -0,0 +1,3 @@
+from langchain_community.tools.databricks.tool import UCFunctionToolkit
+
+__all__ = ["UCFunctionToolkit"]
diff --git a/libs/community/langchain_community/tools/databricks/_execution.py b/libs/community/langchain_community/tools/databricks/_execution.py
new file mode 100644
index 0000000000..6cc0c66156
--- /dev/null
+++ b/libs/community/langchain_community/tools/databricks/_execution.py
@@ -0,0 +1,172 @@
+import json
+from dataclasses import dataclass
+from io import StringIO
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
+
+if TYPE_CHECKING:
+    from databricks.sdk import WorkspaceClient
+    from databricks.sdk.service.catalog import FunctionInfo
+    from databricks.sdk.service.sql import StatementParameterListItem
+
+
+def is_scalar(function: "FunctionInfo") -> bool:
+    from databricks.sdk.service.catalog import ColumnTypeName
+
+    return function.data_type != ColumnTypeName.TABLE_TYPE
+
+
+@dataclass
+class ParameterizedStatement:
+    statement: str
+    parameters: List["StatementParameterListItem"]
+
+
+@dataclass
+class FunctionExecutionResult:
+    """
+    Result of executing a function.
+    We always use a string to present the result value for the AI model to consume.
+ """ + + error: Optional[str] = None + format: Optional[Literal["SCALAR", "CSV"]] = None + value: Optional[str] = None + truncated: Optional[bool] = None + + def to_json(self) -> str: + data = {k: v for (k, v) in self.__dict__.items() if v is not None} + return json.dumps(data) + + +def get_execute_function_sql_stmt( + function: "FunctionInfo", json_params: Dict[str, Any] +) -> ParameterizedStatement: + from databricks.sdk.service.catalog import ColumnTypeName + from databricks.sdk.service.sql import StatementParameterListItem + + parts = [] + output_params = [] + if is_scalar(function): + # TODO: IDENTIFIER(:function) did not work + parts.append(f"SELECT {function.full_name}(") + else: + parts.append(f"SELECT * FROM {function.full_name}(") + if function.input_params is None or function.input_params.parameters is None: + assert ( + not json_params + ), "Function has no parameters but parameters were provided." + else: + args = [] + use_named_args = False + for p in function.input_params.parameters: + if p.name not in json_params: + if p.parameter_default is not None: + use_named_args = True + else: + raise ValueError( + f"Parameter {p.name} is required but not provided." + ) + else: + arg_clause = "" + if use_named_args: + arg_clause += f"{p.name} => " + json_value = json_params[p.name] + if p.type_name in ( + ColumnTypeName.ARRAY, + ColumnTypeName.MAP, + ColumnTypeName.STRUCT, + ): + # Use from_json to restore values of complex types. + json_value_str = json.dumps(json_value) + # TODO: parametrize type + arg_clause += f"from_json(:{p.name}, '{p.type_text}')" + output_params.append( + StatementParameterListItem(name=p.name, value=json_value_str) + ) + elif p.type_name == ColumnTypeName.BINARY: + # Use ubbase64 to restore binary values. + arg_clause += f"unbase64(:{p.name})" + output_params.append( + StatementParameterListItem(name=p.name, value=json_value) + ) + else: + arg_clause += f":{p.name}" + output_params.append( + StatementParameterListItem( + name=p.name, value=json_value, type=p.type_text + ) + ) + args.append(arg_clause) + parts.append(",".join(args)) + parts.append(")") + # TODO: check extra params in kwargs + statement = "".join(parts) + return ParameterizedStatement(statement=statement, parameters=output_params) + + +def execute_function( + ws: "WorkspaceClient", + warehouse_id: str, + function: "FunctionInfo", + parameters: Dict[str, Any], +) -> FunctionExecutionResult: + """ + Execute a function with the given arguments and return the result. + """ + try: + import pandas as pd + except ImportError as e: + raise ImportError( + "Could not import pandas python package. " + "Please install it with `pip install pandas`." + ) from e + from databricks.sdk.service.sql import StatementState + + # TODO: async so we can run functions in parallel + parametrized_statement = get_execute_function_sql_stmt(function, parameters) + # TODO: configurable limits + response = ws.statement_execution.execute_statement( + statement=parametrized_statement.statement, + warehouse_id=warehouse_id, + parameters=parametrized_statement.parameters, + wait_timeout="30s", + row_limit=100, + byte_limit=4096, + ) + status = response.status + assert status is not None, f"Statement execution failed: {response}" + if status.state != StatementState.SUCCEEDED: + error = status.error + assert ( + error is not None + ), "Statement execution failed but no error message was provided." 
+        return FunctionExecutionResult(error=f"{error.error_code}: {error.message}")
+    manifest = response.manifest
+    assert manifest is not None
+    truncated = manifest.truncated
+    result = response.result
+    assert (
+        result is not None
+    ), "Statement execution succeeded but no result was provided."
+    data_array = result.data_array
+    if is_scalar(function):
+        value = None
+        if data_array and len(data_array) > 0 and len(data_array[0]) > 0:
+            value = str(data_array[0][0])  # type: ignore
+        return FunctionExecutionResult(
+            format="SCALAR", value=value, truncated=truncated
+        )
+    else:
+        schema = manifest.schema
+        assert (
+            schema is not None and schema.columns is not None
+        ), "Statement execution succeeded but no schema was provided."
+        columns = [c.name for c in schema.columns]
+        if data_array is None:
+            data_array = []
+        pdf = pd.DataFrame.from_records(data_array, columns=columns)
+        csv_buffer = StringIO()
+        pdf.to_csv(csv_buffer, index=False)
+        return FunctionExecutionResult(
+            format="CSV", value=csv_buffer.getvalue(), truncated=truncated
+        )
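To make the execution path above concrete, here is a minimal usage sketch (illustrative, not part of the diff): it drives `execute_function` directly against the `main.tools.python_exec` function created in the notebook. The warehouse ID is the notebook's placeholder, and default Databricks SDK authentication is assumed.

```python
from databricks.sdk import WorkspaceClient

from langchain_community.tools.databricks._execution import execute_function

# Assumes Databricks credentials are configured (e.g., via environment
# variables) and that main.tools.python_exec exists, as in the notebook.
ws = WorkspaceClient()
function = ws.functions.get("main.tools.python_exec")

# For this scalar function, get_execute_function_sql_stmt builds
# "SELECT main.tools.python_exec(:code)" with `code` passed as a typed
# statement parameter, then runs it on the SQL warehouse.
result = execute_function(
    ws=ws,
    warehouse_id="xxxx123456789",  # placeholder, as in the notebook
    function=function,
    parameters={"code": "print(1 + 1)"},
)
print(result.to_json())  # e.g. {"format": "SCALAR", "value": "2\n", "truncated": false}
```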
diff --git a/libs/community/langchain_community/tools/databricks/tool.py b/libs/community/langchain_community/tools/databricks/tool.py
new file mode 100644
index 0000000000..33f1d9313e
--- /dev/null
+++ b/libs/community/langchain_community/tools/databricks/tool.py
@@ -0,0 +1,201 @@
+import json
+from datetime import date, datetime
+from decimal import Decimal
+from hashlib import md5
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
+
+from langchain_core.pydantic_v1 import BaseModel, Field, create_model
+from langchain_core.tools import BaseTool, BaseToolkit, StructuredTool
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    from databricks.sdk import WorkspaceClient
+    from databricks.sdk.service.catalog import FunctionInfo
+
+from langchain_community.tools.databricks._execution import execute_function
+
+
+def _uc_type_to_pydantic_type(uc_type_json: Union[str, Dict[str, Any]]) -> Type:
+    mapping = {
+        "long": int,
+        "binary": bytes,
+        "boolean": bool,
+        "date": date,
+        "double": float,
+        "float": float,
+        "integer": int,
+        "short": int,
+        "string": str,
+        "timestamp": datetime,
+        "timestamp_ntz": datetime,
+        "byte": int,
+    }
+    if isinstance(uc_type_json, str):
+        if uc_type_json in mapping:
+            return mapping[uc_type_json]
+        else:
+            if uc_type_json.startswith("decimal"):
+                return Decimal
+            elif uc_type_json == "void" or uc_type_json.startswith("interval"):
+                raise TypeError(f"Type {uc_type_json} is not supported.")
+            else:
+                raise TypeError(
+                    f"Unknown type {uc_type_json}. Try upgrading this package."
+                )
+    else:
+        assert isinstance(uc_type_json, dict)
+        tpe = uc_type_json["type"]
+        if tpe == "array":
+            element_type = _uc_type_to_pydantic_type(uc_type_json["elementType"])
+            if uc_type_json["containsNull"]:
+                element_type = Optional[element_type]  # type: ignore
+            return List[element_type]  # type: ignore
+        elif tpe == "map":
+            key_type = uc_type_json["keyType"]
+            assert key_type == "string", TypeError(
+                f"Only STRING key type is supported for MAP, but got {key_type}."
+            )
+            value_type = _uc_type_to_pydantic_type(uc_type_json["valueType"])
+            if uc_type_json["valueContainsNull"]:
+                value_type: Type = Optional[value_type]  # type: ignore
+            return Dict[str, value_type]  # type: ignore
+        elif tpe == "struct":
+            fields = {}
+            for field in uc_type_json["fields"]:
+                field_type = _uc_type_to_pydantic_type(field["type"])
+                if field.get("nullable"):
+                    field_type = Optional[field_type]  # type: ignore
+                comment = (
+                    uc_type_json["metadata"].get("comment")
+                    if "metadata" in uc_type_json
+                    else None
+                )
+                fields[field["name"]] = (field_type, Field(..., description=comment))
+            uc_type_json_str = json.dumps(uc_type_json, sort_keys=True)
+            type_hash = md5(uc_type_json_str.encode()).hexdigest()[:8]
+            return create_model(f"Struct_{type_hash}", **fields)  # type: ignore
+        else:
+            raise TypeError(f"Unknown type {uc_type_json}. Try upgrading this package.")
+
+
+def _generate_args_schema(function: "FunctionInfo") -> Type[BaseModel]:
+    if function.input_params is None:
+        return BaseModel
+    params = function.input_params.parameters
+    assert params is not None
+    fields = {}
+    for p in params:
+        assert p.type_json is not None
+        type_json = json.loads(p.type_json)["type"]
+        pydantic_type = _uc_type_to_pydantic_type(type_json)
+        description = p.comment
+        default: Any = ...
+        if p.parameter_default:
+            pydantic_type = Optional[pydantic_type]  # type: ignore
+            default = None
+            # TODO: Convert default value string to the correct type.
+            # We might need to use statement execution API
+            # to get the JSON representation of the value.
+            default_description = f"(Default: {p.parameter_default})"
+            if description:
+                description += f" {default_description}"
+            else:
+                description = default_description
+        fields[p.name] = (
+            pydantic_type,
+            Field(default=default, description=description),
+        )
+    return create_model(
+        f"{function.catalog_name}__{function.schema_name}__{function.name}__params",
+        **fields,  # type: ignore
+    )
+
+
+def _get_tool_name(function: "FunctionInfo") -> str:
+    tool_name = f"{function.catalog_name}__{function.schema_name}__{function.name}"[
+        -64:
+    ]
+    return tool_name
+
+
+def _get_default_workspace_client() -> "WorkspaceClient":
+    try:
+        from databricks.sdk import WorkspaceClient
+    except ImportError as e:
+        raise ImportError(
+            "Could not import databricks-sdk python package. "
+            "Please install it with `pip install databricks-sdk`."
+        ) from e
+    return WorkspaceClient()
+
+
+class UCFunctionToolkit(BaseToolkit):
+    warehouse_id: str = Field(
+        description="The ID of a Databricks SQL Warehouse to execute functions."
+    )
+
+    workspace_client: "WorkspaceClient" = Field(
+        default_factory=_get_default_workspace_client,
+        description="Databricks workspace client.",
+    )
+
+    tools: Dict[str, BaseTool] = Field(default_factory=dict)
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def include(self, *function_names: str, **kwargs: Any) -> Self:
+        """
+        Includes UC functions in the toolkit.
+
+        Args:
+            function_names: A list of UC function names in the format
+                "catalog_name.schema_name.function_name" or
+                "catalog_name.schema_name.*".
+                If the function name ends with ".*",
+                all functions in the schema will be added.
+            kwargs: Extra arguments to pass to StructuredTool, e.g., `return_direct`.
+ """ + for name in function_names: + if name.endswith(".*"): + catalog_name, schema_name = name[:-2].split(".") + # TODO: handle pagination, warn and truncate if too many + functions = self.workspace_client.functions.list( + catalog_name=catalog_name, schema_name=schema_name + ) + for f in functions: + assert f.full_name is not None + self.include(f.full_name, **kwargs) + else: + if name not in self.tools: + self.tools[name] = self._make_tool(name, **kwargs) + return self + + def _make_tool(self, function_name: str, **kwargs: Any) -> BaseTool: + function = self.workspace_client.functions.get(function_name) + name = _get_tool_name(function) + description = function.comment or "" + args_schema = _generate_args_schema(function) + + def func(*args: Any, **kwargs: Any) -> str: + # TODO: We expect all named args and ignore args. + # Non-empty args show up when the function has no parameters. + args_json = json.loads(json.dumps(kwargs, default=str)) + result = execute_function( + ws=self.workspace_client, + warehouse_id=self.warehouse_id, + function=function, + parameters=args_json, + ) + return result.to_json() + + return StructuredTool( + name=name, + description=description, + args_schema=args_schema, + func=func, + **kwargs, + ) + + def get_tools(self) -> List[BaseTool]: + return list(self.tools.values())