"""Agent for working with Spark DataFrames."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM

from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain_experimental.tools.python.tool import PythonAstREPLTool


def _validate_spark_df(df: Any) -> bool:
    """Return True if *df* is a ``pyspark.sql.DataFrame``.

    Returns False both when *df* is some other type and when ``pyspark.sql``
    cannot be imported.
    """
    # Keep the try body to the import only, so nothing else can be mistaken
    # for a missing dependency.
    try:
        from pyspark.sql import DataFrame as SparkLocalDataFrame
    except ImportError:
        return False
    return isinstance(df, SparkLocalDataFrame)


def _validate_spark_connect_df(df: Any) -> bool:
    """Return True if *df* is a ``pyspark.sql.connect.dataframe.DataFrame``.

    Returns False both when *df* is some other type and when the Spark Connect
    module cannot be imported.
    """
    try:
        from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
    except ImportError:
        return False
    return isinstance(df, SparkConnectDataFrame)


def create_spark_dataframe_agent(
    llm: BaseLLM,
    df: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a Spark agent from an LLM and dataframe.

    NOTE(review): the only tool given to the agent is a ``PythonAstREPLTool``
    whose locals contain the live ``df``. Presumably this lets the LLM run
    Python it generates against the DataFrame -- confirm against the tool's
    documentation and only use this with a trusted model and environment.

    Args:
        llm: Language model used by the underlying ``LLMChain``.
        df: A ``pyspark.sql.DataFrame`` or Spark Connect ``DataFrame``.
        callback_manager: Passed to the ``LLMChain``, the ``ZeroShotAgent``
            and the ``AgentExecutor``.
        prefix: Prompt prefix handed to ``ZeroShotAgent.create_prompt``.
        suffix: Prompt suffix handed to ``ZeroShotAgent.create_prompt``.
        input_variables: Prompt input variables; defaults to
            ``["df", "input", "agent_scratchpad"]`` when ``None``.
        verbose: Forwarded to the ``AgentExecutor``.
        return_intermediate_steps: Forwarded to the ``AgentExecutor``.
        max_iterations: Forwarded to the ``AgentExecutor``.
        max_execution_time: Forwarded to the ``AgentExecutor``.
        early_stopping_method: Forwarded to the ``AgentExecutor``.
        agent_executor_kwargs: Extra keyword arguments for
            ``AgentExecutor.from_agent_and_tools``.
        **kwargs: Extra keyword arguments for ``ZeroShotAgent``.

    Returns:
        An ``AgentExecutor`` wired to a single Python REPL tool holding ``df``.

    Raises:
        ImportError: If pyspark is not installed, or if ``df`` is not a Spark
            DataFrame. The exception type is the same in both cases so that
            existing callers catching ``ImportError`` keep working; the
            message tells the two situations apart.
    """
    if not (_validate_spark_df(df) or _validate_spark_connect_df(df)):
        # Work out WHY validation failed so the message is not misleading.
        try:
            import pyspark  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Spark is not installed. run `pip install pyspark`."
            ) from exc
        raise ImportError(
            "Expected `df` to be a Spark DataFrame (pyspark.sql.DataFrame or "
            "pyspark.sql.connect.dataframe.DataFrame), "
            f"got {type(df).__name__}."
        )

    if input_variables is None:
        input_variables = ["df", "input", "agent_scratchpad"]

    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix, input_variables=input_variables
    )
    # Bake the string form of the first row into the prompt as the `df`
    # variable. Calling first() evaluates the DataFrame here, at construction.
    partial_prompt = prompt.partial(df=str(df.first()))
    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )