mirror of https://github.com/hwchase17/langchain
Deprecate PythonRepl tools and Pandas/Xorbits/Spark DataFrame/Python/CSV agents (#12427)
See discussion here: https://github.com/langchain-ai/langchain/discussions/11680 The code is available for usage from langchain_experimental. The reason for the deprecation is that the agents are relying on a Python REPL. The code can only be run safely with appropriate sandboxing. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>pull/12440/head^2
parent
68e12d34a9
commit
cadfce295f
@ -0,0 +1,36 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
|
||||
# Get directory of langchain package
|
||||
PACKAGE_DIR = HERE.parent
|
||||
SEPARATOR = os.sep
|
||||
|
||||
|
||||
def get_relative_path(
|
||||
file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR
|
||||
) -> str:
|
||||
"""Get the path of the file as a relative path to the package directory."""
|
||||
if isinstance(file, str):
|
||||
file = Path(file)
|
||||
return str(file.relative_to(relative_to))
|
||||
|
||||
|
||||
def as_import_path(
|
||||
file: Union[Path, str],
|
||||
*,
|
||||
suffix: Optional[str] = None,
|
||||
relative_to: Path = PACKAGE_DIR
|
||||
) -> str:
|
||||
"""Path of the file as a LangChain import exclude langchain top namespace."""
|
||||
if isinstance(file, str):
|
||||
file = Path(file)
|
||||
path = get_relative_path(file, relative_to=relative_to)
|
||||
if file.is_file():
|
||||
path = path[: -len(file.suffix)]
|
||||
import_path = path.replace(SEPARATOR, ".")
|
||||
if suffix:
|
||||
import_path += "." + suffix
|
||||
return import_path
|
@ -1 +1,22 @@
|
||||
"""CSV toolkit."""
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise ImportError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
|
@ -1,34 +0,0 @@
|
||||
from io import IOBase
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
def create_csv_agent(
|
||||
llm: BaseLanguageModel,
|
||||
path: Union[str, IOBase, List[Union[str, IOBase]]],
|
||||
pandas_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Create csv agent by loading to a dataframe and using pandas agent."""
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pandas package not found, please install with `pip install pandas`"
|
||||
)
|
||||
|
||||
_kwargs = pandas_kwargs or {}
|
||||
if isinstance(path, (str, IOBase)):
|
||||
df = pd.read_csv(path, **_kwargs)
|
||||
elif isinstance(path, list):
|
||||
df = []
|
||||
for item in path:
|
||||
if not isinstance(item, (str, IOBase)):
|
||||
raise ValueError(f"Expected str or file-like object, got {type(path)}")
|
||||
df.append(pd.read_csv(item, **_kwargs))
|
||||
else:
|
||||
raise ValueError(f"Expected str, list, or file-like object, got {type(path)}")
|
||||
return create_pandas_dataframe_agent(llm, df, **kwargs)
|
@ -1 +1,22 @@
|
||||
"""Pandas toolkit."""
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise ImportError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
|
@ -1,351 +0,0 @@
|
||||
"""Agent for working with pandas objects."""
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from langchain._api import warn_deprecated
|
||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import (
|
||||
FUNCTIONS_WITH_DF,
|
||||
FUNCTIONS_WITH_MULTI_DF,
|
||||
MULTI_DF_PREFIX,
|
||||
MULTI_DF_PREFIX_FUNCTIONS,
|
||||
PREFIX,
|
||||
PREFIX_FUNCTIONS,
|
||||
SUFFIX_NO_DF,
|
||||
SUFFIX_WITH_DF,
|
||||
SUFFIX_WITH_MULTI_DF,
|
||||
)
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.agents.types import AgentType
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
|
||||
def _get_multi_prompt(
|
||||
dfs: List[Any],
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
num_dfs = len(dfs)
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
include_dfs_head = True
|
||||
elif include_df_in_prompt:
|
||||
suffix_to_use = SUFFIX_WITH_MULTI_DF
|
||||
include_dfs_head = True
|
||||
else:
|
||||
suffix_to_use = SUFFIX_NO_DF
|
||||
include_dfs_head = False
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad", "num_dfs"]
|
||||
if include_dfs_head:
|
||||
input_variables += ["dfs_head"]
|
||||
|
||||
if prefix is None:
|
||||
prefix = MULTI_DF_PREFIX
|
||||
|
||||
df_locals = {}
|
||||
for i, dataframe in enumerate(dfs):
|
||||
df_locals[f"df{i + 1}"] = dataframe
|
||||
tools = [PythonAstREPLTool(locals=df_locals)]
|
||||
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
|
||||
)
|
||||
|
||||
partial_prompt = prompt.partial()
|
||||
if "dfs_head" in input_variables:
|
||||
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
|
||||
partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs), dfs_head=dfs_head)
|
||||
if "num_dfs" in input_variables:
|
||||
partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs))
|
||||
return partial_prompt, tools
|
||||
|
||||
|
||||
def _get_single_prompt(
|
||||
df: Any,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
include_df_head = True
|
||||
elif include_df_in_prompt:
|
||||
suffix_to_use = SUFFIX_WITH_DF
|
||||
include_df_head = True
|
||||
else:
|
||||
suffix_to_use = SUFFIX_NO_DF
|
||||
include_df_head = False
|
||||
|
||||
if input_variables is None:
|
||||
input_variables = ["input", "agent_scratchpad"]
|
||||
if include_df_head:
|
||||
input_variables += ["df_head"]
|
||||
|
||||
if prefix is None:
|
||||
prefix = PREFIX
|
||||
|
||||
tools = [PythonAstREPLTool(locals={"df": df})]
|
||||
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
|
||||
)
|
||||
|
||||
partial_prompt = prompt.partial()
|
||||
if "df_head" in input_variables:
|
||||
partial_prompt = partial_prompt.partial(
|
||||
df_head=str(df.head(number_of_head_rows).to_markdown())
|
||||
)
|
||||
return partial_prompt, tools
|
||||
|
||||
|
||||
def _get_prompt_and_tools(
|
||||
df: Any,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
pd.set_option("display.max_columns", None)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pandas package not found, please install with `pip install pandas`"
|
||||
)
|
||||
|
||||
if include_df_in_prompt is not None and suffix is not None:
|
||||
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
|
||||
|
||||
if isinstance(df, list):
|
||||
for item in df:
|
||||
if not isinstance(item, pd.DataFrame):
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_multi_prompt(
|
||||
df,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
)
|
||||
else:
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_single_prompt(
|
||||
df,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
)
|
||||
|
||||
|
||||
def _get_functions_single_prompt(
|
||||
df: Any,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
if include_df_in_prompt:
|
||||
suffix_to_use = suffix_to_use.format(
|
||||
df_head=str(df.head(number_of_head_rows).to_markdown())
|
||||
)
|
||||
elif include_df_in_prompt:
|
||||
suffix_to_use = FUNCTIONS_WITH_DF.format(
|
||||
df_head=str(df.head(number_of_head_rows).to_markdown())
|
||||
)
|
||||
else:
|
||||
suffix_to_use = ""
|
||||
|
||||
if prefix is None:
|
||||
prefix = PREFIX_FUNCTIONS
|
||||
|
||||
tools = [PythonAstREPLTool(locals={"df": df})]
|
||||
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||
return prompt, tools
|
||||
|
||||
|
||||
def _get_functions_multi_prompt(
|
||||
dfs: Any,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
if include_df_in_prompt:
|
||||
dfs_head = "\n\n".join(
|
||||
[d.head(number_of_head_rows).to_markdown() for d in dfs]
|
||||
)
|
||||
suffix_to_use = suffix_to_use.format(
|
||||
dfs_head=dfs_head,
|
||||
)
|
||||
elif include_df_in_prompt:
|
||||
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
|
||||
suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format(
|
||||
dfs_head=dfs_head,
|
||||
)
|
||||
else:
|
||||
suffix_to_use = ""
|
||||
|
||||
if prefix is None:
|
||||
prefix = MULTI_DF_PREFIX_FUNCTIONS
|
||||
prefix = prefix.format(num_dfs=str(len(dfs)))
|
||||
|
||||
df_locals = {}
|
||||
for i, dataframe in enumerate(dfs):
|
||||
df_locals[f"df{i + 1}"] = dataframe
|
||||
tools = [PythonAstREPLTool(locals=df_locals)]
|
||||
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||
return prompt, tools
|
||||
|
||||
|
||||
def _get_functions_prompt_and_tools(
|
||||
df: Any,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
pd.set_option("display.max_columns", None)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"pandas package not found, please install with `pip install pandas`"
|
||||
)
|
||||
if input_variables is not None:
|
||||
raise ValueError("`input_variables` is not supported at the moment.")
|
||||
|
||||
if include_df_in_prompt is not None and suffix is not None:
|
||||
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
|
||||
|
||||
if isinstance(df, list):
|
||||
for item in df:
|
||||
if not isinstance(item, pd.DataFrame):
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_functions_multi_prompt(
|
||||
df,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
)
|
||||
else:
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_functions_single_prompt(
|
||||
df,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
)
|
||||
|
||||
|
||||
def create_pandas_dataframe_agent(
|
||||
llm: BaseLanguageModel,
|
||||
df: Any,
|
||||
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
verbose: bool = False,
|
||||
return_intermediate_steps: bool = False,
|
||||
max_iterations: Optional[int] = 15,
|
||||
max_execution_time: Optional[float] = None,
|
||||
early_stopping_method: str = "force",
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
extra_tools: Sequence[BaseTool] = (),
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a pandas agent from an LLM and dataframe."""
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
agent: BaseSingleActionAgent
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt, base_tools = _get_prompt_and_tools(
|
||||
df,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
)
|
||||
tools = base_tools + list(extra_tools)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
_prompt, base_tools = _get_functions_prompt_and_tools(
|
||||
df,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
)
|
||||
tools = base_tools + list(extra_tools)
|
||||
agent = OpenAIFunctionsAgent(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
return_intermediate_steps=return_intermediate_steps,
|
||||
max_iterations=max_iterations,
|
||||
max_execution_time=max_execution_time,
|
||||
early_stopping_method=early_stopping_method,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
@ -1,44 +0,0 @@
|
||||
# flake8: noqa
|
||||
|
||||
PREFIX = """
|
||||
You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
|
||||
You should use the tools below to answer the question posed of you:"""
|
||||
|
||||
MULTI_DF_PREFIX = """
|
||||
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc. You
|
||||
should use the tools below to answer the question posed of you:"""
|
||||
|
||||
SUFFIX_NO_DF = """
|
||||
Begin!
|
||||
Question: {input}
|
||||
{agent_scratchpad}"""
|
||||
|
||||
SUFFIX_WITH_DF = """
|
||||
This is the result of `print(df.head())`:
|
||||
{df_head}
|
||||
|
||||
Begin!
|
||||
Question: {input}
|
||||
{agent_scratchpad}"""
|
||||
|
||||
SUFFIX_WITH_MULTI_DF = """
|
||||
This is the result of `print(df.head())` for each dataframe:
|
||||
{dfs_head}
|
||||
|
||||
Begin!
|
||||
Question: {input}
|
||||
{agent_scratchpad}"""
|
||||
|
||||
PREFIX_FUNCTIONS = """
|
||||
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""
|
||||
|
||||
MULTI_DF_PREFIX_FUNCTIONS = """
|
||||
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""
|
||||
|
||||
FUNCTIONS_WITH_DF = """
|
||||
This is the result of `print(df.head())`:
|
||||
{df_head}"""
|
||||
|
||||
FUNCTIONS_WITH_MULTI_DF = """
|
||||
This is the result of `print(df.head())` for each dataframe:
|
||||
{dfs_head}"""
|
@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise ImportError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
@ -1,69 +0,0 @@
|
||||
"""Python agent."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain._api import warn_deprecated
|
||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain.agents.agent_toolkits.python.prompt import PREFIX
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||
from langchain.agents.types import AgentType
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from langchain.tools.python.tool import PythonREPLTool
|
||||
|
||||
|
||||
def create_python_agent(
|
||||
llm: BaseLanguageModel,
|
||||
tool: PythonREPLTool,
|
||||
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
verbose: bool = False,
|
||||
prefix: str = PREFIX,
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a python agent from an LLM and tool."""
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
tools = [tool]
|
||||
agent: BaseSingleActionAgent
|
||||
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
system_message = SystemMessage(content=prefix)
|
||||
_prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||
agent = OpenAIFunctionsAgent(
|
||||
llm=llm,
|
||||
prompt=_prompt,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
@ -1,9 +0,0 @@
|
||||
# flake8: noqa
|
||||
|
||||
PREFIX = """You are an agent designed to write and execute python code to answer questions.
|
||||
You have access to a python REPL, which you can use to execute python code.
|
||||
If you get an error, debug your code and try again.
|
||||
Only use the output of your code to answer the question.
|
||||
You might know the answer without running any code, but you should still run the code to get the answer.
|
||||
If it does not seem like you can write code to answer the question, just return "I don't know" as the answer.
|
||||
"""
|
@ -1 +1,22 @@
|
||||
"""spark toolkit"""
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise ImportError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
|
@ -1,91 +0,0 @@
|
||||
"""Agent for working with pandas objects."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain._api import warn_deprecated
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
|
||||
def _validate_spark_df(df: Any) -> bool:
|
||||
try:
|
||||
from pyspark.sql import DataFrame as SparkLocalDataFrame
|
||||
|
||||
return isinstance(df, SparkLocalDataFrame)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _validate_spark_connect_df(df: Any) -> bool:
|
||||
try:
|
||||
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
|
||||
|
||||
return isinstance(df, SparkConnectDataFrame)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def create_spark_dataframe_agent(
|
||||
llm: BaseLLM,
|
||||
df: Any,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = PREFIX,
|
||||
suffix: str = SUFFIX,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
verbose: bool = False,
|
||||
return_intermediate_steps: bool = False,
|
||||
max_iterations: Optional[int] = 15,
|
||||
max_execution_time: Optional[float] = None,
|
||||
early_stopping_method: str = "force",
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a Spark agent from an LLM and dataframe."""
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
|
||||
if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
|
||||
raise ImportError("Spark is not installed. run `pip install pyspark`.")
|
||||
|
||||
if input_variables is None:
|
||||
input_variables = ["df", "input", "agent_scratchpad"]
|
||||
tools = [PythonAstREPLTool(locals={"df": df})]
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
|
||||
)
|
||||
partial_prompt = prompt.partial(df=str(df.first()))
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=partial_prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
return_intermediate_steps=return_intermediate_steps,
|
||||
max_iterations=max_iterations,
|
||||
max_execution_time=max_execution_time,
|
||||
early_stopping_method=early_stopping_method,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
@ -1,13 +0,0 @@
|
||||
# flake8: noqa
|
||||
|
||||
PREFIX = """
|
||||
You are working with a spark dataframe in Python. The name of the dataframe is `df`.
|
||||
You should use the tools below to answer the question posed of you:"""
|
||||
|
||||
SUFFIX = """
|
||||
This is the result of `print(df.first())`:
|
||||
{df}
|
||||
|
||||
Begin!
|
||||
Question: {input}
|
||||
{agent_scratchpad}"""
|
@ -1 +1,22 @@
|
||||
"""Xorbits toolkit."""
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain._api.path import as_import_path
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
"""Get attr name."""
|
||||
|
||||
here = as_import_path(Path(__file__).parent)
|
||||
|
||||
old_path = "langchain." + here + "." + name
|
||||
new_path = "langchain_experimental." + here + "." + name
|
||||
raise ImportError(
|
||||
"This agent has been moved to langchain experiment. "
|
||||
"This agent relies on python REPL tool under the hood, so to use it "
|
||||
"safely please sandbox the python REPL. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"and https://github.com/langchain-ai/langchain/discussions/11680"
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
f"update your import statement from:\n `{old_path}` to `{new_path}`."
|
||||
)
|
||||
|
@ -1,101 +0,0 @@
|
||||
"""Agent for working with xorbits objects."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain._api import warn_deprecated
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.xorbits.prompt import (
|
||||
NP_PREFIX,
|
||||
NP_SUFFIX,
|
||||
PD_PREFIX,
|
||||
PD_SUFFIX,
|
||||
)
|
||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
|
||||
def create_xorbits_agent(
|
||||
llm: BaseLLM,
|
||||
data: Any,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
input_variables: Optional[List[str]] = None,
|
||||
verbose: bool = False,
|
||||
return_intermediate_steps: bool = False,
|
||||
max_iterations: Optional[int] = 15,
|
||||
max_execution_time: Optional[float] = None,
|
||||
early_stopping_method: str = "force",
|
||||
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentExecutor:
|
||||
"""Construct a xorbits agent from an LLM and dataframe."""
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
try:
|
||||
from xorbits import numpy as np
|
||||
from xorbits import pandas as pd
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Xorbits package not installed, please install with `pip install xorbits`"
|
||||
)
|
||||
|
||||
if not isinstance(data, (pd.DataFrame, np.ndarray)):
|
||||
raise ValueError(
|
||||
f"Expected Xorbits DataFrame or ndarray object, got {type(data)}"
|
||||
)
|
||||
if input_variables is None:
|
||||
input_variables = ["data", "input", "agent_scratchpad"]
|
||||
tools = [PythonAstREPLTool(locals={"data": data})]
|
||||
prompt, partial_input = None, None
|
||||
|
||||
if isinstance(data, pd.DataFrame):
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=PD_PREFIX if prefix == "" else prefix,
|
||||
suffix=PD_SUFFIX if suffix == "" else suffix,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
partial_input = str(data.head())
|
||||
else:
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=NP_PREFIX if prefix == "" else prefix,
|
||||
suffix=NP_SUFFIX if suffix == "" else suffix,
|
||||
input_variables=input_variables,
|
||||
)
|
||||
partial_input = str(data[: len(data) // 2])
|
||||
partial_prompt = prompt.partial(data=partial_input)
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=partial_prompt,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(
|
||||
llm_chain=llm_chain,
|
||||
allowed_tools=tool_names,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
verbose=verbose,
|
||||
return_intermediate_steps=return_intermediate_steps,
|
||||
max_iterations=max_iterations,
|
||||
max_execution_time=max_execution_time,
|
||||
early_stopping_method=early_stopping_method,
|
||||
**(agent_executor_kwargs or {}),
|
||||
)
|
@ -1,33 +0,0 @@
|
||||
PD_PREFIX = """
|
||||
You are working with Xorbits dataframe object in Python.
|
||||
Before importing Numpy or Pandas in the current script,
|
||||
remember to import the xorbits version of the library instead.
|
||||
To import the xorbits version of Numpy, replace the original import statement
|
||||
`import pandas as pd` with `import xorbits.pandas as pd`.
|
||||
The name of the input is `data`.
|
||||
You should use the tools below to answer the question posed of you:"""
|
||||
|
||||
PD_SUFFIX = """
|
||||
This is the result of `print(data)`:
|
||||
{data}
|
||||
|
||||
Begin!
|
||||
Question: {input}
|
||||
{agent_scratchpad}"""
|
||||
|
||||
NP_PREFIX = """
|
||||
You are working with Xorbits ndarray object in Python.
|
||||
Before importing Numpy in the current script,
|
||||
remember to import the xorbits version of the library instead.
|
||||
To import the xorbits version of Numpy, replace the original import statement
|
||||
`import numpy as np` with `import xorbits.numpy as np`.
|
||||
The name of the input is `data`.
|
||||
You should use the tools below to answer the question posed of you:"""
|
||||
|
||||
NP_SUFFIX = """
|
||||
This is the result of `print(data)`:
|
||||
{data}
|
||||
|
||||
Begin!
|
||||
Question: {input}
|
||||
{agent_scratchpad}"""
|
@ -0,0 +1,13 @@
|
||||
def raise_on_import() -> None:
|
||||
"""Raise on import letting users know that underlying code is deprecated."""
|
||||
raise ImportError(
|
||||
"This tool has been moved to langchain experiment. "
|
||||
"This tool has access to a python REPL. "
|
||||
"For best practices make sure to sandbox this tool. "
|
||||
"Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
|
||||
"To keep using this code as is, install langchain experimental and "
|
||||
"update relevant imports replacing 'langchain' with 'langchain_experimental'"
|
||||
)
|
||||
|
||||
|
||||
raise_on_import()
|
@ -1,194 +0,0 @@
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
from contextlib import redirect_stdout
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from langchain._api import warn_deprecated
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.python import PythonREPL
|
||||
|
||||
|
||||
def _get_default_python_repl() -> PythonREPL:
|
||||
return PythonREPL(_globals=globals(), _locals=None)
|
||||
|
||||
|
||||
def sanitize_input(query: str) -> str:
|
||||
"""Sanitize input to the python REPL.
|
||||
|
||||
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
|
||||
|
||||
Args:
|
||||
query: The query to sanitize
|
||||
|
||||
Returns:
|
||||
str: The sanitized query
|
||||
"""
|
||||
|
||||
# Removes `, whitespace & python from start
|
||||
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
|
||||
# Removes whitespace & ` from end
|
||||
query = re.sub(r"(\s|`)*$", "", query)
|
||||
return query
|
||||
|
||||
|
||||
class PythonREPLTool(BaseTool):
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
name: str = "Python_REPL"
|
||||
description: str = (
|
||||
"A Python shell. Use this to execute python commands. "
|
||||
"Input should be a valid python command. "
|
||||
"If you want to see the output of a value, you should print it out "
|
||||
"with `print(...)`."
|
||||
)
|
||||
python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
|
||||
sanitize_input: bool = True
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
if self.sanitize_input:
|
||||
query = sanitize_input(query)
|
||||
return self.python_repl.run(query)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
if self.sanitize_input:
|
||||
query = sanitize_input(query)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self.run, query)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class PythonInputs(BaseModel):
|
||||
"""Python inputs."""
|
||||
|
||||
query: str = Field(description="code snippet to run")
|
||||
|
||||
|
||||
class PythonAstREPLTool(BaseTool):
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
name: str = "python_repl_ast"
|
||||
description: str = (
|
||||
"A Python shell. Use this to execute python commands. "
|
||||
"Input should be a valid python command. "
|
||||
"When using this tool, sometimes output is abbreviated - "
|
||||
"make sure it does not look abbreviated before using it in your answer."
|
||||
)
|
||||
globals: Optional[Dict] = Field(default_factory=dict)
|
||||
locals: Optional[Dict] = Field(default_factory=dict)
|
||||
sanitize_input: bool = True
|
||||
args_schema: Type[BaseModel] = PythonInputs
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_python_version(cls, values: Dict) -> Dict:
|
||||
"""Validate valid python version."""
|
||||
if sys.version_info < (3, 9):
|
||||
raise ValueError(
|
||||
"This tool relies on Python 3.9 or higher "
|
||||
"(as it uses new functionality in the `ast` module, "
|
||||
f"you have Python version: {sys.version}"
|
||||
)
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
|
||||
try:
|
||||
if self.sanitize_input:
|
||||
query = sanitize_input(query)
|
||||
tree = ast.parse(query)
|
||||
module = ast.Module(tree.body[:-1], type_ignores=[])
|
||||
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
||||
module_end = ast.Module(tree.body[-1:], type_ignores=[])
|
||||
module_end_str = ast.unparse(module_end) # type: ignore
|
||||
io_buffer = StringIO()
|
||||
try:
|
||||
with redirect_stdout(io_buffer):
|
||||
ret = eval(module_end_str, self.globals, self.locals)
|
||||
if ret is None:
|
||||
return io_buffer.getvalue()
|
||||
else:
|
||||
return ret
|
||||
except Exception:
|
||||
with redirect_stdout(io_buffer):
|
||||
exec(module_end_str, self.globals, self.locals)
|
||||
return io_buffer.getvalue()
|
||||
except Exception as e:
|
||||
return "{}: {}".format(type(e).__name__, str(e))
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
warn_deprecated(
|
||||
since="0.0.314",
|
||||
message=(
|
||||
"On 2023-10-27 this module will be be deprecated from langchain, and "
|
||||
"will be available from the langchain-experimental package."
|
||||
"This code is already available in langchain-experimental."
|
||||
"See https://github.com/langchain-ai/langchain/discussions/11680."
|
||||
),
|
||||
pending=True,
|
||||
)
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self._run, query)
|
||||
|
||||
return result
|
@ -1,76 +0,0 @@
|
||||
import io
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from _pytest.tmpdir import TempPathFactory
|
||||
from pandas import DataFrame
|
||||
|
||||
from langchain.agents import create_csv_agent
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def csv(tmp_path_factory: TempPathFactory) -> DataFrame:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
filename = str(tmp_path_factory.mktemp("data") / "test.csv")
|
||||
df.to_csv(filename)
|
||||
return filename
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def csv_list(tmp_path_factory: TempPathFactory) -> DataFrame:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df1 = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
filename1 = str(tmp_path_factory.mktemp("data") / "test1.csv")
|
||||
df1.to_csv(filename1)
|
||||
|
||||
random_data = np.random.rand(2, 2)
|
||||
df2 = DataFrame(random_data, columns=["name", "height"])
|
||||
filename2 = str(tmp_path_factory.mktemp("data") / "test2.csv")
|
||||
df2.to_csv(filename2)
|
||||
|
||||
return [filename1, filename2]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def csv_file_like(tmp_path_factory: TempPathFactory) -> io.BytesIO:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
buffer = io.BytesIO()
|
||||
df.to_pickle(buffer)
|
||||
return buffer
|
||||
|
||||
|
||||
def test_csv_agent_creation(csv: str) -> None:
|
||||
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
||||
|
||||
def test_single_csv(csv: str) -> None:
|
||||
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("How many rows in the csv? Give me a number.")
|
||||
result = re.search(r".*(4).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_multi_csv(csv_list: list) -> None:
|
||||
agent = create_csv_agent(OpenAI(temperature=0), csv_list, verbose=True)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("How many combined rows in the two csvs? Give me a number.")
|
||||
result = re.search(r".*(6).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_file_like(file_like: io.BytesIO) -> None:
|
||||
agent = create_csv_agent(OpenAI(temperature=0), file_like, verbose=True)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("How many rows in the csv? Give me a number.")
|
||||
result = re.search(r".*(4).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
@ -1,70 +0,0 @@
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from pandas import DataFrame
|
||||
|
||||
from langchain.agents import create_pandas_dataframe_agent
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def df() -> DataFrame:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
return df
|
||||
|
||||
|
||||
# Figure out type hint here
|
||||
@pytest.fixture(scope="module")
|
||||
def df_list() -> list:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df1 = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
random_data = np.random.rand(2, 2)
|
||||
df2 = DataFrame(random_data, columns=["name", "height"])
|
||||
df_list = [df1, df2]
|
||||
return df_list
|
||||
|
||||
|
||||
def test_pandas_agent_creation(df: DataFrame) -> None:
|
||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
||||
|
||||
def test_data_reading(df: DataFrame) -> None:
|
||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("how many rows in df? Give me a number.")
|
||||
result = re.search(rf".*({df.shape[0]}).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_data_reading_no_df_in_prompt(df: DataFrame) -> None:
|
||||
agent = create_pandas_dataframe_agent(
|
||||
OpenAI(temperature=0), df, include_df_in_prompt=False
|
||||
)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("how many rows in df? Give me a number.")
|
||||
result = re.search(rf".*({df.shape[0]}).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_multi_df(df_list: list) -> None:
|
||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df_list, verbose=True)
|
||||
response = agent.run("how many total rows in the two dataframes? Give me a number.")
|
||||
result = re.search(r".*(6).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_multi_df_no_df_in_prompt(df_list: list) -> None:
|
||||
agent = create_pandas_dataframe_agent(
|
||||
OpenAI(temperature=0), df_list, include_df_in_prompt=False
|
||||
)
|
||||
response = agent.run("how many total rows in the two dataframes? Give me a number.")
|
||||
result = re.search(r".*(6).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
|
||||
from langchain._api import path
|
||||
|
||||
HERE = Path(__file__).parent
|
||||
|
||||
ROOT = HERE.parent.parent.parent
|
||||
|
||||
|
||||
def test_as_import_path() -> None:
|
||||
"""Test that the path is converted to a LangChain import path."""
|
||||
# Verify that default paths are correct
|
||||
assert path.PACKAGE_DIR == ROOT / "langchain"
|
||||
# Verify that as import path works correctly
|
||||
assert path.as_import_path(HERE, relative_to=ROOT) == "tests.unit_tests._api"
|
||||
assert (
|
||||
path.as_import_path(__file__, relative_to=ROOT)
|
||||
== "tests.unit_tests._api.test_path"
|
||||
)
|
||||
assert (
|
||||
path.as_import_path(__file__, suffix="create_agent", relative_to=ROOT)
|
||||
== "tests.unit_tests._api.test_path.create_agent"
|
||||
)
|
@ -1,112 +0,0 @@
|
||||
"""Test functionality of Python REPL."""
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
||||
from langchain.utilities import PythonREPL
|
||||
|
||||
_SAMPLE_CODE = """
|
||||
```
|
||||
def multiply():
|
||||
print(5*6)
|
||||
multiply()
|
||||
```
|
||||
"""
|
||||
|
||||
_AST_SAMPLE_CODE = """
|
||||
```
|
||||
def multiply():
|
||||
return(5*6)
|
||||
multiply()
|
||||
```
|
||||
"""
|
||||
|
||||
_AST_SAMPLE_CODE_EXECUTE = """
|
||||
```
|
||||
def multiply(a, b):
|
||||
return(5*6)
|
||||
a = 5
|
||||
b = 6
|
||||
|
||||
multiply(a, b)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def test_python_repl() -> None:
|
||||
"""Test functionality when globals/locals are not provided."""
|
||||
repl = PythonREPL()
|
||||
|
||||
# Run a simple initial command.
|
||||
repl.run("foo = 1")
|
||||
assert repl.locals is not None
|
||||
assert repl.locals["foo"] == 1
|
||||
|
||||
# Now run a command that accesses `foo` to make sure it still has it.
|
||||
repl.run("bar = foo * 2")
|
||||
assert repl.locals is not None
|
||||
assert repl.locals["bar"] == 2
|
||||
|
||||
|
||||
def test_python_repl_no_previous_variables() -> None:
|
||||
"""Test that it does not have access to variables created outside the scope."""
|
||||
foo = 3 # noqa: F841
|
||||
repl = PythonREPL()
|
||||
output = repl.run("print(foo)")
|
||||
assert output == """NameError("name 'foo' is not defined")"""
|
||||
|
||||
|
||||
def test_python_repl_pass_in_locals() -> None:
|
||||
"""Test functionality when passing in locals."""
|
||||
_locals = {"foo": 4}
|
||||
repl = PythonREPL(_locals=_locals)
|
||||
repl.run("bar = foo * 2")
|
||||
assert repl.locals is not None
|
||||
assert repl.locals["bar"] == 8
|
||||
|
||||
|
||||
def test_functionality() -> None:
|
||||
"""Test correct functionality."""
|
||||
chain = PythonREPL()
|
||||
code = "print(1 + 1)"
|
||||
output = chain.run(code)
|
||||
assert output == "2\n"
|
||||
|
||||
|
||||
def test_functionality_multiline() -> None:
|
||||
"""Test correct functionality for ChatGPT multiline commands."""
|
||||
chain = PythonREPL()
|
||||
tool = PythonREPLTool(python_repl=chain)
|
||||
output = tool.run(_SAMPLE_CODE)
|
||||
assert output == "30\n"
|
||||
|
||||
|
||||
def test_python_ast_repl_multiline() -> None:
|
||||
"""Test correct functionality for ChatGPT multiline commands."""
|
||||
if sys.version_info < (3, 9):
|
||||
pytest.skip("Python 3.9+ is required for this test")
|
||||
tool = PythonAstREPLTool()
|
||||
output = tool.run(_AST_SAMPLE_CODE)
|
||||
assert output == 30
|
||||
|
||||
|
||||
def test_python_ast_repl_multi_statement() -> None:
|
||||
"""Test correct functionality for ChatGPT multi statement commands."""
|
||||
if sys.version_info < (3, 9):
|
||||
pytest.skip("Python 3.9+ is required for this test")
|
||||
tool = PythonAstREPLTool()
|
||||
output = tool.run(_AST_SAMPLE_CODE_EXECUTE)
|
||||
assert output == 30
|
||||
|
||||
|
||||
def test_function() -> None:
|
||||
"""Test correct functionality."""
|
||||
chain = PythonREPL()
|
||||
code = "def add(a, b): " " return a + b"
|
||||
output = chain.run(code)
|
||||
assert output == ""
|
||||
|
||||
code = "print(add(1, 2))"
|
||||
output = chain.run(code)
|
||||
assert output == "3\n"
|
@ -1,164 +0,0 @@
|
||||
"""Test Python REPL Tools."""
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain.tools.python.tool import (
|
||||
PythonAstREPLTool,
|
||||
PythonREPLTool,
|
||||
sanitize_input,
|
||||
)
|
||||
|
||||
|
||||
def test_python_repl_tool_single_input() -> None:
|
||||
"""Test that the python REPL tool works with a single input."""
|
||||
tool = PythonREPLTool()
|
||||
assert tool.is_single_input
|
||||
assert int(tool.run("print(1 + 1)").strip()) == 2
|
||||
|
||||
|
||||
def test_python_repl_print() -> None:
|
||||
program = """
|
||||
import numpy as np
|
||||
v1 = np.array([1, 2, 3])
|
||||
v2 = np.array([4, 5, 6])
|
||||
dot_product = np.dot(v1, v2)
|
||||
print("The dot product is {:d}.".format(dot_product))
|
||||
"""
|
||||
tool = PythonREPLTool()
|
||||
assert tool.run(program) == "The dot product is 32.\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_tool_single_input() -> None:
|
||||
"""Test that the python REPL tool works with a single input."""
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.is_single_input
|
||||
assert tool.run("1 + 1") == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_return() -> None:
|
||||
program = """
|
||||
```
|
||||
import numpy as np
|
||||
v1 = np.array([1, 2, 3])
|
||||
v2 = np.array([4, 5, 6])
|
||||
dot_product = np.dot(v1, v2)
|
||||
int(dot_product)
|
||||
```
|
||||
"""
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == 32
|
||||
|
||||
program = """
|
||||
```python
|
||||
import numpy as np
|
||||
v1 = np.array([1, 2, 3])
|
||||
v2 = np.array([4, 5, 6])
|
||||
dot_product = np.dot(v1, v2)
|
||||
int(dot_product)
|
||||
```
|
||||
"""
|
||||
assert tool.run(program) == 32
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_print() -> None:
|
||||
program = """python
|
||||
string = "racecar"
|
||||
if string == string[::-1]:
|
||||
print(string, "is a palindrome")
|
||||
else:
|
||||
print(string, "is not a palindrome")"""
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "racecar is a palindrome\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_repl_print_python_backticks() -> None:
|
||||
program = "`print('`python` is a great language.')`"
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "`python` is a great language.\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_raise_exception() -> None:
|
||||
data = {"Name": ["John", "Alice"], "Age": [30, 25]}
|
||||
program = """
|
||||
import pandas as pd
|
||||
df = pd.DataFrame(data)
|
||||
df['Gender']
|
||||
"""
|
||||
tool = PythonAstREPLTool(locals={"data": data})
|
||||
expected_outputs = (
|
||||
"KeyError: 'Gender'",
|
||||
"ModuleNotFoundError: No module named 'pandas'",
|
||||
)
|
||||
assert tool.run(program) in expected_outputs
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_one_line_print() -> None:
|
||||
program = 'print("The square of {} is {:.2f}".format(3, 3**2))'
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "The square of 3 is 9.00\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_one_line_return() -> None:
|
||||
arr = np.array([1, 2, 3, 4, 5])
|
||||
tool = PythonAstREPLTool(locals={"arr": arr})
|
||||
program = "`(arr**2).sum() # Returns sum of squares`"
|
||||
assert tool.run(program) == 55
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_one_line_exception() -> None:
|
||||
program = "[1, 2, 3][4]"
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "IndexError: list index out of range"
|
||||
|
||||
|
||||
def test_sanitize_input() -> None:
|
||||
query = """
|
||||
```
|
||||
p = 5
|
||||
```
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
|
||||
query = """
|
||||
```python
|
||||
p = 5
|
||||
```
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
|
||||
query = """
|
||||
p = 5
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
Loading…
Reference in New Issue