mirror of https://github.com/hwchase17/langchain
Deprecate PythonRepl tools and Pandas/Xorbits/Spark DataFrame/Python/CSV agents (#12427)
See discussion here: https://github.com/langchain-ai/langchain/discussions/11680 The code is available for usage from langchain_experimental. The reason for the deprecation is that the agents rely on a Python REPL, so the code can only be run safely with appropriate sandboxing. --------- Co-authored-by: Bagatur <baskaryan@gmail.com> pull/12440/head^2
parent
68e12d34a9
commit
cadfce295f
@ -0,0 +1,36 @@
|
|||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
# Directory that contains this file.
HERE = Path(__file__).parent

# Directory of the langchain package: the parent of this file's directory.
PACKAGE_DIR = HERE.parent

# Platform path separator; replaced by "." when building dotted import paths.
SEPARATOR = os.sep
|
||||||
|
|
||||||
|
|
||||||
|
def get_relative_path(
    file: Union[Path, str], *, relative_to: Path = PACKAGE_DIR
) -> str:
    """Return ``file`` as a string path relative to ``relative_to``.

    Args:
        file: Path, or string path, to convert.
        relative_to: Base directory; defaults to the package directory.

    Returns:
        The relative path rendered with ``str``.

    Raises:
        ValueError: Propagated from ``Path.relative_to`` when ``file`` does
            not lie under ``relative_to``.
    """
    target = file if not isinstance(file, str) else Path(file)
    return str(target.relative_to(relative_to))
|
||||||
|
|
||||||
|
|
||||||
|
def as_import_path(
    file: Union[Path, str],
    *,
    suffix: Optional[str] = None,
    relative_to: Path = PACKAGE_DIR,
) -> str:
    """Return the dotted import path of ``file``, without the top namespace.

    The path is made relative to ``relative_to`` and path separators are
    replaced by dots. For an existing regular file the extension is dropped.

    Args:
        file: File or directory, as a ``Path`` or string.
        suffix: Optional trailing component (for example an attribute name)
            appended after a dot.
        relative_to: Base directory the import path is computed from;
            defaults to the package directory.

    Returns:
        The dotted path, e.g. ``"agents.agent_toolkits.csv"``.

    Raises:
        ValueError: Propagated from ``get_relative_path`` when ``file`` does
            not lie under ``relative_to``.
    """
    if isinstance(file, str):
        file = Path(file)
    path = get_relative_path(file, relative_to=relative_to)
    # Only strip an extension that actually exists: with an empty suffix the
    # slice ``path[:-0]`` is ``path[:0]`` and would erase the whole path.
    if file.is_file() and file.suffix:
        path = path[: -len(file.suffix)]
    import_path = path.replace(SEPARATOR, ".")
    if suffix:
        import_path += "." + suffix
    return import_path
|
@ -1 +1,22 @@
|
|||||||
"""CSV toolkit."""
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain._api.path import as_import_path
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
    """Module-level attribute hook reporting that the CSV agent has moved.

    Args:
        name: Attribute looked up on this module that was not found.

    Raises:
        AttributeError: For dunder names, so that introspection such as
            ``hasattr`` or the ``__all__`` lookup done by
            ``from module import *`` behaves as for any missing attribute.
        ImportError: For every other name, with the old and the new
            (``langchain_experimental``) import path in the message.
    """
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Dotted path of this package relative to the langchain package root.
    here = as_import_path(Path(__file__).parent)

    old_path = "langchain." + here + "." + name
    new_path = "langchain_experimental." + here + "." + name
    raise ImportError(
        "This agent has been moved to langchain experimental. "
        "This agent relies on python REPL tool under the hood, so to use it "
        "safely please sandbox the python REPL. "
        "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
        "and https://github.com/langchain-ai/langchain/discussions/11680 "
        "To keep using this code as is, install langchain experimental and "
        f"update your import statement from:\n `{old_path}` to `{new_path}`."
    )
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
from io import IOBase
|
|
||||||
from typing import Any, List, Optional, Union
|
|
||||||
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
|
||||||
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
|
||||||
|
|
||||||
|
|
||||||
def create_csv_agent(
    llm: BaseLanguageModel,
    path: Union[str, IOBase, List[Union[str, IOBase]]],
    pandas_kwargs: Optional[dict] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Create csv agent by loading to a dataframe and using pandas agent.

    Args:
        llm: Language model handed to ``create_pandas_dataframe_agent``.
        path: A CSV path or file-like object, or a list of them. A list
            produces a list of dataframes.
        pandas_kwargs: Extra keyword arguments forwarded to ``pd.read_csv``.
        **kwargs: Forwarded to ``create_pandas_dataframe_agent``.

    Returns:
        The executor returned by ``create_pandas_dataframe_agent``.

    Raises:
        ImportError: If pandas is not installed.
        ValueError: If ``path`` (or an element of a list ``path``) is neither
            a string nor a file-like object.
    """
    try:
        import pandas as pd
    except ImportError as e:
        raise ImportError(
            "pandas package not found, please install with `pip install pandas`"
        ) from e

    _kwargs = pandas_kwargs or {}
    if isinstance(path, (str, IOBase)):
        df = pd.read_csv(path, **_kwargs)
    elif isinstance(path, list):
        df = []
        for item in path:
            if not isinstance(item, (str, IOBase)):
                # Report the offending element, not the enclosing list.
                raise ValueError(
                    f"Expected str or file-like object, got {type(item)}"
                )
            df.append(pd.read_csv(item, **_kwargs))
    else:
        raise ValueError(f"Expected str, list, or file-like object, got {type(path)}")
    return create_pandas_dataframe_agent(llm, df, **kwargs)
|
|
@ -1 +1,22 @@
|
|||||||
"""Pandas toolkit."""
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain._api.path import as_import_path
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
    """Module-level attribute hook reporting that the pandas agent has moved.

    Args:
        name: Attribute looked up on this module that was not found.

    Raises:
        AttributeError: For dunder names, so that introspection such as
            ``hasattr`` or the ``__all__`` lookup done by
            ``from module import *`` behaves as for any missing attribute.
        ImportError: For every other name, with the old and the new
            (``langchain_experimental``) import path in the message.
    """
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Dotted path of this package relative to the langchain package root.
    here = as_import_path(Path(__file__).parent)

    old_path = "langchain." + here + "." + name
    new_path = "langchain_experimental." + here + "." + name
    raise ImportError(
        "This agent has been moved to langchain experimental. "
        "This agent relies on python REPL tool under the hood, so to use it "
        "safely please sandbox the python REPL. "
        "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
        "and https://github.com/langchain-ai/langchain/discussions/11680 "
        "To keep using this code as is, install langchain experimental and "
        f"update your import statement from:\n `{old_path}` to `{new_path}`."
    )
|
||||||
|
@ -1,351 +0,0 @@
|
|||||||
"""Agent for working with pandas objects."""
|
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
|
||||||
|
|
||||||
from langchain._api import warn_deprecated
|
|
||||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
|
||||||
from langchain.agents.agent_toolkits.pandas.prompt import (
|
|
||||||
FUNCTIONS_WITH_DF,
|
|
||||||
FUNCTIONS_WITH_MULTI_DF,
|
|
||||||
MULTI_DF_PREFIX,
|
|
||||||
MULTI_DF_PREFIX_FUNCTIONS,
|
|
||||||
PREFIX,
|
|
||||||
PREFIX_FUNCTIONS,
|
|
||||||
SUFFIX_NO_DF,
|
|
||||||
SUFFIX_WITH_DF,
|
|
||||||
SUFFIX_WITH_MULTI_DF,
|
|
||||||
)
|
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
|
||||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
|
||||||
from langchain.agents.types import AgentType
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.schema import BasePromptTemplate
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
|
||||||
from langchain.schema.messages import SystemMessage
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
from langchain.tools.python.tool import PythonAstREPLTool
|
|
||||||
|
|
||||||
|
|
||||||
def _get_multi_prompt(
    dfs: List[Any],
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the zero-shot prompt and REPL tool for several dataframes.

    The dataframes are exposed to the REPL tool as ``df1``, ``df2``, ... in
    list order.

    Args:
        dfs: Dataframes to expose.
        prefix: Prompt prefix; ``MULTI_DF_PREFIX`` when ``None``.
        suffix: Prompt suffix; when ``None`` it is chosen from
            ``include_df_in_prompt``.
        input_variables: Prompt input variables; defaults are derived when
            ``None``.
        include_df_in_prompt: Whether the default suffix shows the head of
            each dataframe.
        number_of_head_rows: Rows of each dataframe rendered into the prompt.

    Returns:
        The partially-filled prompt and the list holding the REPL tool.
    """
    df_count = len(dfs)

    # Choose the suffix template and whether dataframe previews are wanted.
    if suffix is not None:
        chosen_suffix, show_heads = suffix, True
    elif include_df_in_prompt:
        chosen_suffix, show_heads = SUFFIX_WITH_MULTI_DF, True
    else:
        chosen_suffix, show_heads = SUFFIX_NO_DF, False

    if input_variables is None:
        input_variables = ["input", "agent_scratchpad", "num_dfs"]
        if show_heads:
            input_variables.append("dfs_head")

    if prefix is None:
        prefix = MULTI_DF_PREFIX

    repl_locals = {f"df{pos}": frame for pos, frame in enumerate(dfs, start=1)}
    tools = [PythonAstREPLTool(locals=repl_locals)]

    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=chosen_suffix,
        input_variables=input_variables,
    )

    partial_prompt = prompt.partial()
    if "dfs_head" in input_variables:
        previews = [frame.head(number_of_head_rows).to_markdown() for frame in dfs]
        partial_prompt = partial_prompt.partial(
            num_dfs=str(df_count), dfs_head="\n\n".join(previews)
        )
    if "num_dfs" in input_variables:
        partial_prompt = partial_prompt.partial(num_dfs=str(df_count))
    return partial_prompt, tools
|
|
||||||
|
|
||||||
|
|
||||||
def _get_single_prompt(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the zero-shot prompt and REPL tool for a single dataframe.

    The dataframe is exposed to the REPL tool under the name ``df``.

    Args:
        df: Dataframe to expose.
        prefix: Prompt prefix; ``PREFIX`` when ``None``.
        suffix: Prompt suffix; when ``None`` it is chosen from
            ``include_df_in_prompt``.
        input_variables: Prompt input variables; defaults are derived when
            ``None``.
        include_df_in_prompt: Whether the default suffix shows ``df.head()``.
        number_of_head_rows: Rows of the dataframe rendered into the prompt.

    Returns:
        The partially-filled prompt and the list holding the REPL tool.
    """
    # Choose the suffix template and whether a dataframe preview is wanted.
    if suffix is not None:
        chosen_suffix, show_head = suffix, True
    elif include_df_in_prompt:
        chosen_suffix, show_head = SUFFIX_WITH_DF, True
    else:
        chosen_suffix, show_head = SUFFIX_NO_DF, False

    if input_variables is None:
        input_variables = ["input", "agent_scratchpad"]
        if show_head:
            input_variables.append("df_head")

    if prefix is None:
        prefix = PREFIX

    tools = [PythonAstREPLTool(locals={"df": df})]

    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=chosen_suffix,
        input_variables=input_variables,
    )

    partial_prompt = prompt.partial()
    if "df_head" in input_variables:
        preview = str(df.head(number_of_head_rows).to_markdown())
        partial_prompt = partial_prompt.partial(df_head=preview)
    return partial_prompt, tools
|
|
||||||
|
|
||||||
|
|
||||||
def _get_prompt_and_tools(
|
|
||||||
df: Any,
|
|
||||||
prefix: Optional[str] = None,
|
|
||||||
suffix: Optional[str] = None,
|
|
||||||
input_variables: Optional[List[str]] = None,
|
|
||||||
include_df_in_prompt: Optional[bool] = True,
|
|
||||||
number_of_head_rows: int = 5,
|
|
||||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
|
||||||
try:
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
pd.set_option("display.max_columns", None)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"pandas package not found, please install with `pip install pandas`"
|
|
||||||
)
|
|
||||||
|
|
||||||
if include_df_in_prompt is not None and suffix is not None:
|
|
||||||
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
|
|
||||||
|
|
||||||
if isinstance(df, list):
|
|
||||||
for item in df:
|
|
||||||
if not isinstance(item, pd.DataFrame):
|
|
||||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
|
||||||
return _get_multi_prompt(
|
|
||||||
df,
|
|
||||||
prefix=prefix,
|
|
||||||
suffix=suffix,
|
|
||||||
input_variables=input_variables,
|
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
|
||||||
number_of_head_rows=number_of_head_rows,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if not isinstance(df, pd.DataFrame):
|
|
||||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
|
||||||
return _get_single_prompt(
|
|
||||||
df,
|
|
||||||
prefix=prefix,
|
|
||||||
suffix=suffix,
|
|
||||||
input_variables=input_variables,
|
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
|
||||||
number_of_head_rows=number_of_head_rows,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_functions_single_prompt(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the OpenAI-functions prompt and REPL tool for one dataframe.

    The system message is ``prefix`` followed by the suffix. When
    ``include_df_in_prompt`` is truthy, the suffix (custom or
    ``FUNCTIONS_WITH_DF``) is formatted with the dataframe head as
    ``df_head``.

    Args:
        df: Dataframe exposed to the REPL tool as ``df``.
        prefix: System-message prefix; ``PREFIX_FUNCTIONS`` when ``None``.
        suffix: Optional custom suffix template.
        include_df_in_prompt: Whether to render the dataframe head.
        number_of_head_rows: Rows of the dataframe rendered into the prompt.

    Returns:
        The prompt and the list holding the REPL tool.
    """
    if include_df_in_prompt:
        template = suffix if suffix is not None else FUNCTIONS_WITH_DF
        head_text = str(df.head(number_of_head_rows).to_markdown())
        tail = template.format(df_head=head_text)
    elif suffix is not None:
        tail = suffix
    else:
        tail = ""

    if prefix is None:
        prefix = PREFIX_FUNCTIONS

    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=prefix + tail)
    )
    return prompt, tools
|
|
||||||
|
|
||||||
|
|
||||||
def _get_functions_multi_prompt(
    dfs: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the OpenAI-functions prompt and REPL tool for several dataframes.

    The system message is the prefix (formatted with ``num_dfs``) followed by
    the suffix. When ``include_df_in_prompt`` is truthy, the suffix (custom or
    ``FUNCTIONS_WITH_MULTI_DF``) is formatted with the joined dataframe heads
    as ``dfs_head``.

    Args:
        dfs: Dataframes exposed to the REPL tool as ``df1``, ``df2``, ...
        prefix: Prefix template; ``MULTI_DF_PREFIX_FUNCTIONS`` when ``None``.
        suffix: Optional custom suffix template.
        include_df_in_prompt: Whether to render the dataframe heads.
        number_of_head_rows: Rows of each dataframe rendered into the prompt.

    Returns:
        The prompt and the list holding the REPL tool.
    """
    if include_df_in_prompt:
        template = suffix if suffix is not None else FUNCTIONS_WITH_MULTI_DF
        previews = [frame.head(number_of_head_rows).to_markdown() for frame in dfs]
        tail = template.format(dfs_head="\n\n".join(previews))
    elif suffix is not None:
        tail = suffix
    else:
        tail = ""

    if prefix is None:
        prefix = MULTI_DF_PREFIX_FUNCTIONS
    prefix = prefix.format(num_dfs=str(len(dfs)))

    repl_locals = {f"df{pos}": frame for pos, frame in enumerate(dfs, start=1)}
    tools = [PythonAstREPLTool(locals=repl_locals)]
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=prefix + tail)
    )
    return prompt, tools
|
|
||||||
|
|
||||||
|
|
||||||
def _get_functions_prompt_and_tools(
|
|
||||||
df: Any,
|
|
||||||
prefix: Optional[str] = None,
|
|
||||||
suffix: Optional[str] = None,
|
|
||||||
input_variables: Optional[List[str]] = None,
|
|
||||||
include_df_in_prompt: Optional[bool] = True,
|
|
||||||
number_of_head_rows: int = 5,
|
|
||||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
|
||||||
try:
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
pd.set_option("display.max_columns", None)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"pandas package not found, please install with `pip install pandas`"
|
|
||||||
)
|
|
||||||
if input_variables is not None:
|
|
||||||
raise ValueError("`input_variables` is not supported at the moment.")
|
|
||||||
|
|
||||||
if include_df_in_prompt is not None and suffix is not None:
|
|
||||||
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
|
|
||||||
|
|
||||||
if isinstance(df, list):
|
|
||||||
for item in df:
|
|
||||||
if not isinstance(item, pd.DataFrame):
|
|
||||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
|
||||||
return _get_functions_multi_prompt(
|
|
||||||
df,
|
|
||||||
prefix=prefix,
|
|
||||||
suffix=suffix,
|
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
|
||||||
number_of_head_rows=number_of_head_rows,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if not isinstance(df, pd.DataFrame):
|
|
||||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
|
||||||
return _get_functions_single_prompt(
|
|
||||||
df,
|
|
||||||
prefix=prefix,
|
|
||||||
suffix=suffix,
|
|
||||||
include_df_in_prompt=include_df_in_prompt,
|
|
||||||
number_of_head_rows=number_of_head_rows,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_pandas_dataframe_agent(
    llm: BaseLanguageModel,
    df: Any,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
    extra_tools: Sequence[BaseTool] = (),
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a pandas agent from an LLM and dataframe.

    Emits a pending-deprecation warning, builds the prompt and the Python
    REPL tool for ``df``, creates the agent for ``agent_type`` and wraps it
    in an ``AgentExecutor``.

    Args:
        llm: Language model driving the agent.
        df: A pandas ``DataFrame`` or a list of them.
        agent_type: ``ZERO_SHOT_REACT_DESCRIPTION`` or ``OPENAI_FUNCTIONS``.
        callback_manager: Passed to the chain, the agent and the executor.
        prefix: Optional prompt prefix.
        suffix: Optional prompt suffix.
        input_variables: Optional prompt input variables.
        verbose: Passed to the executor.
        return_intermediate_steps: Passed to the executor.
        max_iterations: Passed to the executor.
        max_execution_time: Passed to the executor.
        early_stopping_method: Passed to the executor.
        agent_executor_kwargs: Extra keyword arguments for the executor.
        include_df_in_prompt: Whether dataframe heads go into the prompt.
        number_of_head_rows: Rows of each dataframe rendered into the prompt.
        extra_tools: Additional tools appended after the REPL tool.
        **kwargs: Forwarded to the agent constructor.

    Returns:
        The configured ``AgentExecutor``.

    Raises:
        ValueError: If ``agent_type`` is not one of the two supported types.
    """
    warn_deprecated(
        since="0.0.314",
        message=(
            "On 2023-10-27 this module will be deprecated from langchain, and "
            "will be available from the langchain-experimental package. "
            "This code is already available in langchain-experimental. "
            "See https://github.com/langchain-ai/langchain/discussions/11680."
        ),
        pending=True,
    )
    agent: BaseSingleActionAgent
    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        prompt, base_tools = _get_prompt_and_tools(
            df,
            prefix=prefix,
            suffix=suffix,
            input_variables=input_variables,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )
        tools = base_tools + list(extra_tools)
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        agent = ZeroShotAgent(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            callback_manager=callback_manager,
            **kwargs,
        )
    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        _prompt, base_tools = _get_functions_prompt_and_tools(
            df,
            prefix=prefix,
            suffix=suffix,
            input_variables=input_variables,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )
        tools = base_tools + list(extra_tools)
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
|
|
@ -1,44 +0,0 @@
|
|||||||
# flake8: noqa
|
|
||||||
|
|
||||||
# Prompt fragments for the pandas dataframe agent. Placeholders in braces are
# filled in by the prompt-building helpers.

# Zero-shot prefix, single dataframe.
PREFIX = (
    "\n"
    "You are working with a pandas dataframe in Python. The name of the dataframe is `df`.\n"
    "You should use the tools below to answer the question posed of you:"
)

# Zero-shot prefix, several dataframes.
MULTI_DF_PREFIX = (
    "\n"
    "You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc. You\n"
    "should use the tools below to answer the question posed of you:"
)

# Zero-shot suffix without a dataframe preview.
SUFFIX_NO_DF = (
    "\n"
    "Begin!\n"
    "Question: {input}\n"
    "{agent_scratchpad}"
)

# Zero-shot suffix with the head of a single dataframe.
SUFFIX_WITH_DF = (
    "\n"
    "This is the result of `print(df.head())`:\n"
    "{df_head}\n"
    "\n"
    "Begin!\n"
    "Question: {input}\n"
    "{agent_scratchpad}"
)

# Zero-shot suffix with the heads of several dataframes.
SUFFIX_WITH_MULTI_DF = (
    "\n"
    "This is the result of `print(df.head())` for each dataframe:\n"
    "{dfs_head}\n"
    "\n"
    "Begin!\n"
    "Question: {input}\n"
    "{agent_scratchpad}"
)

# OpenAI-functions prefix, single dataframe.
PREFIX_FUNCTIONS = (
    "\n"
    "You are working with a pandas dataframe in Python. The name of the dataframe is `df`."
)

# OpenAI-functions prefix, several dataframes.
MULTI_DF_PREFIX_FUNCTIONS = (
    "\n"
    "You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."
)

# OpenAI-functions suffix with the head of a single dataframe.
FUNCTIONS_WITH_DF = (
    "\n"
    "This is the result of `print(df.head())`:\n"
    "{df_head}"
)

# OpenAI-functions suffix with the heads of several dataframes.
FUNCTIONS_WITH_MULTI_DF = (
    "\n"
    "This is the result of `print(df.head())` for each dataframe:\n"
    "{dfs_head}"
)
|
|
@ -0,0 +1,22 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain._api.path import as_import_path
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
    """Module-level attribute hook reporting that the Python agent has moved.

    Args:
        name: Attribute looked up on this module that was not found.

    Raises:
        AttributeError: For dunder names, so that introspection such as
            ``hasattr`` or the ``__all__`` lookup done by
            ``from module import *`` behaves as for any missing attribute.
        ImportError: For every other name, with the old and the new
            (``langchain_experimental``) import path in the message.
    """
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Dotted path of this package relative to the langchain package root.
    here = as_import_path(Path(__file__).parent)

    old_path = "langchain." + here + "." + name
    new_path = "langchain_experimental." + here + "." + name
    raise ImportError(
        "This agent has been moved to langchain experimental. "
        "This agent relies on python REPL tool under the hood, so to use it "
        "safely please sandbox the python REPL. "
        "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
        "and https://github.com/langchain-ai/langchain/discussions/11680 "
        "To keep using this code as is, install langchain experimental and "
        f"update your import statement from:\n `{old_path}` to `{new_path}`."
    )
|
@ -1,69 +0,0 @@
|
|||||||
"""Python agent."""
|
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from langchain._api import warn_deprecated
|
|
||||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
|
||||||
from langchain.agents.agent_toolkits.python.prompt import PREFIX
|
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
|
||||||
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
|
||||||
from langchain.agents.types import AgentType
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
|
||||||
from langchain.schema.messages import SystemMessage
|
|
||||||
from langchain.tools.python.tool import PythonREPLTool
|
|
||||||
|
|
||||||
|
|
||||||
def create_python_agent(
    llm: BaseLanguageModel,
    tool: PythonREPLTool,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    verbose: bool = False,
    prefix: str = PREFIX,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a python agent from an LLM and tool.

    Emits a pending-deprecation warning, creates the agent for ``agent_type``
    with ``tool`` as its only tool, and wraps it in an ``AgentExecutor``.

    Args:
        llm: Language model driving the agent.
        tool: Python REPL tool the agent may call.
        agent_type: ``ZERO_SHOT_REACT_DESCRIPTION`` or ``OPENAI_FUNCTIONS``.
        callback_manager: Passed to the chain / agent and the executor.
        verbose: Passed to the executor.
        prefix: Prompt prefix (system message for the functions agent).
        agent_executor_kwargs: Extra keyword arguments for the executor.
        **kwargs: Forwarded to the agent constructor.

    Returns:
        The configured ``AgentExecutor``.

    Raises:
        ValueError: If ``agent_type`` is not one of the two supported types.
    """
    warn_deprecated(
        since="0.0.314",
        message=(
            "On 2023-10-27 this module will be deprecated from langchain, and "
            "will be available from the langchain-experimental package. "
            "This code is already available in langchain-experimental. "
            "See https://github.com/langchain-ai/langchain/discussions/11680."
        ),
        pending=True,
    )
    tools = [tool]
    agent: BaseSingleActionAgent

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        # Distinct loop name so the ``tool`` parameter is not shadowed.
        tool_names = [t.name for t in tools]
        agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        system_message = SystemMessage(content=prefix)
        _prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        **(agent_executor_kwargs or {}),
    )
|
|
@ -1,9 +0,0 @@
|
|||||||
# flake8: noqa
|
|
||||||
|
|
||||||
# Default prompt prefix for the Python REPL agent.
PREFIX = (
    "You are an agent designed to write and execute python code to answer questions.\n"
    "You have access to a python REPL, which you can use to execute python code.\n"
    "If you get an error, debug your code and try again.\n"
    "Only use the output of your code to answer the question.\n"
    "You might know the answer without running any code, but you should still run the code to get the answer.\n"
    "If it does not seem like you can write code to answer the question, just return \"I don't know\" as the answer.\n"
)
|
|
@ -1 +1,22 @@
|
|||||||
"""spark toolkit"""
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain._api.path import as_import_path
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
    """Module-level attribute hook reporting that the Spark agent has moved.

    Args:
        name: Attribute looked up on this module that was not found.

    Raises:
        AttributeError: For dunder names, so that introspection such as
            ``hasattr`` or the ``__all__`` lookup done by
            ``from module import *`` behaves as for any missing attribute.
        ImportError: For every other name, with the old and the new
            (``langchain_experimental``) import path in the message.
    """
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Dotted path of this package relative to the langchain package root.
    here = as_import_path(Path(__file__).parent)

    old_path = "langchain." + here + "." + name
    new_path = "langchain_experimental." + here + "." + name
    raise ImportError(
        "This agent has been moved to langchain experimental. "
        "This agent relies on python REPL tool under the hood, so to use it "
        "safely please sandbox the python REPL. "
        "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
        "and https://github.com/langchain-ai/langchain/discussions/11680 "
        "To keep using this code as is, install langchain experimental and "
        f"update your import statement from:\n `{old_path}` to `{new_path}`."
    )
|
||||||
|
@ -1,91 +0,0 @@
|
|||||||
"""Agent for working with pandas objects."""
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from langchain._api import warn_deprecated
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
|
||||||
from langchain.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
|
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.llms.base import BaseLLM
|
|
||||||
from langchain.tools.python.tool import PythonAstREPLTool
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_spark_df(df: Any) -> bool:
|
|
||||||
try:
|
|
||||||
from pyspark.sql import DataFrame as SparkLocalDataFrame
|
|
||||||
|
|
||||||
return isinstance(df, SparkLocalDataFrame)
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_spark_connect_df(df: Any) -> bool:
|
|
||||||
try:
|
|
||||||
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
|
|
||||||
|
|
||||||
return isinstance(df, SparkConnectDataFrame)
|
|
||||||
except ImportError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def create_spark_dataframe_agent(
    llm: BaseLLM,
    df: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a Spark agent from an LLM and dataframe.

    Emits a pending-deprecation warning, exposes ``df`` to a Python REPL
    tool, fills the prompt's ``df`` variable with ``str(df.first())`` and
    wraps a zero-shot agent in an ``AgentExecutor``.

    Args:
        llm: Language model driving the agent.
        df: A Spark (local or Spark Connect) dataframe.
        callback_manager: Passed to the chain, the agent and the executor.
        prefix: Prompt prefix.
        suffix: Prompt suffix.
        input_variables: Prompt input variables; defaults to
            ``["df", "input", "agent_scratchpad"]``.
        verbose: Passed to the executor.
        return_intermediate_steps: Passed to the executor.
        max_iterations: Passed to the executor.
        max_execution_time: Passed to the executor.
        early_stopping_method: Passed to the executor.
        agent_executor_kwargs: Extra keyword arguments for the executor.
        **kwargs: Forwarded to the agent constructor.

    Returns:
        The configured ``AgentExecutor``.

    Raises:
        ImportError: If ``df`` is not recognised as a Spark dataframe.
    """
    warn_deprecated(
        since="0.0.314",
        message=(
            "On 2023-10-27 this module will be deprecated from langchain, and "
            "will be available from the langchain-experimental package. "
            "This code is already available in langchain-experimental. "
            "See https://github.com/langchain-ai/langchain/discussions/11680."
        ),
        pending=True,
    )

    # NOTE(review): this also fires when pyspark is installed but ``df`` is
    # simply not a Spark dataframe; the exception type is kept for callers.
    if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
        raise ImportError("Spark is not installed. run `pip install pyspark`.")

    if input_variables is None:
        input_variables = ["df", "input", "agent_scratchpad"]
    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix, input_variables=input_variables
    )
    partial_prompt = prompt.partial(df=str(df.first()))
    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
|
|
@ -1,13 +0,0 @@
|
|||||||
# flake8: noqa
|
|
||||||
|
|
||||||
# Prompt fragments for the Spark dataframe agent. Placeholders in braces are
# filled in when the prompt is built.

# Prompt prefix.
PREFIX = (
    "\n"
    "You are working with a spark dataframe in Python. The name of the dataframe is `df`.\n"
    "You should use the tools below to answer the question posed of you:"
)

# Prompt suffix showing the first row of the dataframe.
SUFFIX = (
    "\n"
    "This is the result of `print(df.first())`:\n"
    "{df}\n"
    "\n"
    "Begin!\n"
    "Question: {input}\n"
    "{agent_scratchpad}"
)
|
|
@ -1 +1,22 @@
|
|||||||
"""Xorbits toolkit."""
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain._api.path import as_import_path
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str) -> Any:
    """Module-level attribute hook reporting that the Xorbits agent has moved.

    Args:
        name: Attribute looked up on this module that was not found.

    Raises:
        AttributeError: For dunder names, so that introspection such as
            ``hasattr`` or the ``__all__`` lookup done by
            ``from module import *`` behaves as for any missing attribute.
        ImportError: For every other name, with the old and the new
            (``langchain_experimental``) import path in the message.
    """
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Dotted path of this package relative to the langchain package root.
    here = as_import_path(Path(__file__).parent)

    old_path = "langchain." + here + "." + name
    new_path = "langchain_experimental." + here + "." + name
    raise ImportError(
        "This agent has been moved to langchain experimental. "
        "This agent relies on python REPL tool under the hood, so to use it "
        "safely please sandbox the python REPL. "
        "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
        "and https://github.com/langchain-ai/langchain/discussions/11680 "
        "To keep using this code as is, install langchain experimental and "
        f"update your import statement from:\n `{old_path}` to `{new_path}`."
    )
|
||||||
|
@ -1,101 +0,0 @@
|
|||||||
"""Agent for working with xorbits objects."""
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from langchain._api import warn_deprecated
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
|
||||||
from langchain.agents.agent_toolkits.xorbits.prompt import (
|
|
||||||
NP_PREFIX,
|
|
||||||
NP_SUFFIX,
|
|
||||||
PD_PREFIX,
|
|
||||||
PD_SUFFIX,
|
|
||||||
)
|
|
||||||
from langchain.agents.mrkl.base import ZeroShotAgent
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.chains.llm import LLMChain
|
|
||||||
from langchain.llms.base import BaseLLM
|
|
||||||
from langchain.tools.python.tool import PythonAstREPLTool
|
|
||||||
|
|
||||||
|
|
||||||
def create_xorbits_agent(
    llm: BaseLLM,
    data: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = "",
    suffix: str = "",
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> AgentExecutor:
    """Construct a xorbits agent from an LLM and dataframe.

    Args:
        llm: Language model driving the agent's ``LLMChain``.
        data: A Xorbits ``DataFrame`` or ``ndarray``; exposed to the REPL tool
            under the name ``data``.
        callback_manager: Forwarded to the chain, the agent and the executor.
        prefix: Prompt prefix; ``""`` selects the pandas/numpy default.
        suffix: Prompt suffix; ``""`` selects the pandas/numpy default.
        input_variables: Prompt variables; defaults to
            ``["data", "input", "agent_scratchpad"]``.
        verbose: Forwarded to the executor.
        return_intermediate_steps: Forwarded to the executor.
        max_iterations: Forwarded to the executor.
        max_execution_time: Forwarded to the executor.
        early_stopping_method: Forwarded to the executor.
        agent_executor_kwargs: Extra keyword arguments for the executor.
        **kwargs: Extra keyword arguments for ``ZeroShotAgent``.

    Returns:
        An ``AgentExecutor`` wired with a single ``PythonAstREPLTool``.

    Raises:
        ImportError: If the ``xorbits`` package is not installed.
        ValueError: If ``data`` is neither a Xorbits DataFrame nor ndarray.
    """
    warn_deprecated(
        since="0.0.314",
        message=(
            "On 2023-10-27 this module will be deprecated from langchain, and "
            "will be available from the langchain-experimental package. "
            "This code is already available in langchain-experimental. "
            "See https://github.com/langchain-ai/langchain/discussions/11680."
        ),
        pending=True,
    )
    try:
        from xorbits import numpy as np
        from xorbits import pandas as pd
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Xorbits package not installed, please install with `pip install xorbits`"
        ) from e

    if not isinstance(data, (pd.DataFrame, np.ndarray)):
        raise ValueError(
            f"Expected Xorbits DataFrame or ndarray object, got {type(data)}"
        )
    if input_variables is None:
        input_variables = ["data", "input", "agent_scratchpad"]
    # NOTE: the REPL tool executes model-generated code; run it sandboxed.
    tools = [PythonAstREPLTool(locals={"data": data})]

    if isinstance(data, pd.DataFrame):
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=PD_PREFIX if prefix == "" else prefix,
            suffix=PD_SUFFIX if suffix == "" else suffix,
            input_variables=input_variables,
        )
        # Only the head of the frame is shown to the model.
        partial_input = str(data.head())
    else:
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=NP_PREFIX if prefix == "" else prefix,
            suffix=NP_SUFFIX if suffix == "" else suffix,
            input_variables=input_variables,
        )
        # Only the first half of the array is shown to the model.
        partial_input = str(data[: len(data) // 2])
    partial_prompt = prompt.partial(data=partial_input)
    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
|
|
@ -1,33 +0,0 @@
|
|||||||
# Default prompt prefix for the Xorbits DataFrame case.
PD_PREFIX = """
You are working with Xorbits dataframe object in Python.
Before importing Numpy or Pandas in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Pandas, replace the original import statement
`import pandas as pd` with `import xorbits.pandas as pd`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""
|
|
||||||
|
|
||||||
PD_SUFFIX = """
|
|
||||||
This is the result of `print(data)`:
|
|
||||||
{data}
|
|
||||||
|
|
||||||
Begin!
|
|
||||||
Question: {input}
|
|
||||||
{agent_scratchpad}"""
|
|
||||||
|
|
||||||
NP_PREFIX = """
|
|
||||||
You are working with Xorbits ndarray object in Python.
|
|
||||||
Before importing Numpy in the current script,
|
|
||||||
remember to import the xorbits version of the library instead.
|
|
||||||
To import the xorbits version of Numpy, replace the original import statement
|
|
||||||
`import numpy as np` with `import xorbits.numpy as np`.
|
|
||||||
The name of the input is `data`.
|
|
||||||
You should use the tools below to answer the question posed of you:"""
|
|
||||||
|
|
||||||
NP_SUFFIX = """
|
|
||||||
This is the result of `print(data)`:
|
|
||||||
{data}
|
|
||||||
|
|
||||||
Begin!
|
|
||||||
Question: {input}
|
|
||||||
{agent_scratchpad}"""
|
|
@ -0,0 +1,13 @@
|
|||||||
|
def raise_on_import() -> None:
    """Raise on import letting users know that underlying code is deprecated.

    Raises:
        ImportError: Always; the message points users at
            ``langchain_experimental`` and asks them to sandbox the tool.
    """
    raise ImportError(
        "This tool has been moved to langchain experimental. "
        "This tool has access to a python REPL. "
        "For best practices make sure to sandbox this tool. "
        "Read https://github.com/langchain-ai/langchain/blob/master/SECURITY.md "
        "To keep using this code as is, install langchain experimental and "
        "update relevant imports replacing 'langchain' with 'langchain_experimental'"
    )
|
||||||
|
|
||||||
|
|
||||||
|
raise_on_import()
|
@ -1,194 +0,0 @@
|
|||||||
"""A tool for running python code in a REPL."""
|
|
||||||
|
|
||||||
import ast
|
|
||||||
import asyncio
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from contextlib import redirect_stdout
|
|
||||||
from io import StringIO
|
|
||||||
from typing import Any, Dict, Optional, Type
|
|
||||||
|
|
||||||
from langchain._api import warn_deprecated
|
|
||||||
from langchain.callbacks.manager import (
|
|
||||||
AsyncCallbackManagerForToolRun,
|
|
||||||
CallbackManagerForToolRun,
|
|
||||||
)
|
|
||||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
|
||||||
from langchain.tools.base import BaseTool
|
|
||||||
from langchain.utilities.python import PythonREPL
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_python_repl() -> PythonREPL:
    """Build a ``PythonREPL`` seeded with this module's globals and no locals."""
    default_repl = PythonREPL(_globals=globals(), _locals=None)
    return default_repl
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_input(query: str) -> str:
    """Sanitize input to the python REPL.

    Remove whitespace, backtick & python (if llm mistakes python console as terminal)

    The leading ``python`` tag is only removed when it stands alone, i.e. is
    followed by whitespace, a backtick or the end of the input. Code that
    merely starts with an identifier such as ``python_version`` is kept intact.

    Args:
        query: The query to sanitize

    Returns:
        str: The sanitized query
    """
    # Removes `, whitespace & a standalone `python` tag from start. The
    # lookahead stops the tag from eating the head of a longer identifier.
    query = re.sub(r"^[\s`]*(?i:python(?=\s|`|$))?\s*", "", query)
    # Removes whitespace & ` from end
    query = re.sub(r"[\s`]*$", "", query)
    return query
|
|
||||||
|
|
||||||
|
|
||||||
class PythonREPLTool(BaseTool):
    """A tool for running python code in a REPL.

    NOTE: the query is executed by ``python_repl`` as-is (after optional
    sanitizing); run this tool only inside an appropriate sandbox.
    """

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "If you want to see the output of a value, you should print it out "
        "with `print(...)`."
    )
    # REPL that actually runs the code; a fresh default one per tool instance.
    python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
    # When true, strip backticks/whitespace/"python" tag before running.
    sanitize_input: bool = True

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool.

        Args:
            query: Python source to execute.
            run_manager: Callback manager for this run (unused here).

        Returns:
            Whatever ``python_repl.run`` returns for the query.
        """
        warn_deprecated(
            since="0.0.314",
            message=(
                "On 2023-10-27 this module will be deprecated from langchain, and "
                "will be available from the langchain-experimental package. "
                "This code is already available in langchain-experimental. "
                "See https://github.com/langchain-ai/langchain/discussions/11680."
            ),
            pending=True,
        )
        if self.sanitize_input:
            query = sanitize_input(query)
        return self.python_repl.run(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously.

        Args:
            query: Python source to execute.
            run_manager: Async callback manager for this run (unused here).

        Returns:
            The result of running the synchronous ``run`` in the default executor.
        """
        warn_deprecated(
            since="0.0.314",
            message=(
                "On 2023-10-27 this module will be deprecated from langchain, and "
                "will be available from the langchain-experimental package. "
                "This code is already available in langchain-experimental. "
                "See https://github.com/langchain-ai/langchain/discussions/11680."
            ),
            pending=True,
        )
        if self.sanitize_input:
            query = sanitize_input(query)

        # Off-load the blocking run to the loop's default executor.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, self.run, query)

        return result
|
|
||||||
|
|
||||||
|
|
||||||
class PythonInputs(BaseModel):
    """Input schema with a single ``query`` field holding the code to run."""

    # Source code handed to the tool.
    query: str = Field(description="code snippet to run")
|
|
||||||
|
|
||||||
|
|
||||||
class PythonAstREPLTool(BaseTool):
    """A tool for running python code in a REPL.

    The query is parsed with ``ast``; every statement but the last is executed,
    and the last one is evaluated so its value can be returned.

    NOTE: this uses ``exec``/``eval`` on the query without any isolation; run
    this tool only inside an appropriate sandbox.
    """

    name: str = "python_repl_ast"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "When using this tool, sometimes output is abbreviated - "
        "make sure it does not look abbreviated before using it in your answer."
    )
    # Namespaces passed to exec/eval; they persist across calls on one tool.
    globals: Optional[Dict] = Field(default_factory=dict)
    locals: Optional[Dict] = Field(default_factory=dict)
    # When true, strip backticks/whitespace/"python" tag before parsing.
    sanitize_input: bool = True
    args_schema: Type[BaseModel] = PythonInputs

    @root_validator(pre=True)
    def validate_python_version(cls, values: Dict) -> Dict:
        """Validate valid python version.

        Raises:
            ValueError: If the interpreter is older than Python 3.9.
        """
        if sys.version_info < (3, 9):
            raise ValueError(
                "This tool relies on Python 3.9 or higher "
                "(as it uses new functionality in the `ast` module), "
                f"you have Python version: {sys.version}"
            )
        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool.

        Args:
            query: Python source to execute.
            run_manager: Callback manager for this run (unused here).

        Returns:
            The value of the final expression if it is not ``None``; otherwise
            the text printed while running the final statement. If anything
            raises, the string ``"<ExceptionName>: <message>"``.
        """
        warn_deprecated(
            since="0.0.314",
            message=(
                "On 2023-10-27 this module will be deprecated from langchain, and "
                "will be available from the langchain-experimental package. "
                "This code is already available in langchain-experimental. "
                "See https://github.com/langchain-ai/langchain/discussions/11680."
            ),
            pending=True,
        )

        try:
            if self.sanitize_input:
                query = sanitize_input(query)
            tree = ast.parse(query)
            # Everything but the last statement runs only for its side effects.
            module = ast.Module(tree.body[:-1], type_ignores=[])
            exec(ast.unparse(module), self.globals, self.locals)  # type: ignore
            module_end = ast.Module(tree.body[-1:], type_ignores=[])
            module_end_str = ast.unparse(module_end)  # type: ignore
            io_buffer = StringIO()
            with redirect_stdout(io_buffer):
                # Decide eval vs. exec from the node type instead of retrying
                # with exec after a failed eval, which ran a raising final
                # expression (and its side effects) twice.
                if tree.body and isinstance(tree.body[-1], ast.Expr):
                    ret = eval(module_end_str, self.globals, self.locals)
                    if ret is not None:
                        return ret
                else:
                    exec(module_end_str, self.globals, self.locals)
            return io_buffer.getvalue()
        except Exception as e:
            # Errors are reported back as text rather than raised.
            return "{}: {}".format(type(e).__name__, str(e))

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously.

        Args:
            query: Python source to execute.
            run_manager: Async callback manager for this run (unused here).

        Returns:
            The result of running ``_run`` in the default executor.
        """

        warn_deprecated(
            since="0.0.314",
            message=(
                "On 2023-10-27 this module will be deprecated from langchain, and "
                "will be available from the langchain-experimental package. "
                "This code is already available in langchain-experimental. "
                "See https://github.com/langchain-ai/langchain/discussions/11680."
            ),
            pending=True,
        )
        # Off-load the blocking run to the loop's default executor.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, self._run, query)

        return result
|
|
@ -1,76 +0,0 @@
|
|||||||
import io
|
|
||||||
import re
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
from _pytest.tmpdir import TempPathFactory
|
|
||||||
from pandas import DataFrame
|
|
||||||
|
|
||||||
from langchain.agents import create_csv_agent
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
|
||||||
from langchain.llms import OpenAI
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
def csv(tmp_path_factory: TempPathFactory) -> str:
    """Write a random 4x4 DataFrame to a temporary CSV file.

    Returns:
        The path of the written file, as a string.
    """
    random_data = np.random.rand(4, 4)
    df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
    filename = str(tmp_path_factory.mktemp("data") / "test.csv")
    df.to_csv(filename)
    return filename
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
def csv_list(tmp_path_factory: TempPathFactory) -> list:
    """Write two random DataFrames (4 rows and 2 rows) to temporary CSV files.

    Returns:
        The two file paths, as a list of strings.
    """
    random_data = np.random.rand(4, 4)
    df1 = DataFrame(random_data, columns=["name", "age", "food", "sport"])
    filename1 = str(tmp_path_factory.mktemp("data") / "test1.csv")
    df1.to_csv(filename1)

    random_data = np.random.rand(2, 2)
    df2 = DataFrame(random_data, columns=["name", "height"])
    filename2 = str(tmp_path_factory.mktemp("data") / "test2.csv")
    df2.to_csv(filename2)

    return [filename1, filename2]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
def csv_file_like(tmp_path_factory: TempPathFactory) -> io.BytesIO:
    """Serialize a random 4x4 DataFrame as CSV into an in-memory buffer.

    Returns:
        A ``BytesIO`` holding the CSV bytes, positioned at the start.
    """
    random_data = np.random.rand(4, 4)
    df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
    buffer = io.BytesIO()
    # The consumer is a CSV agent, so the payload must be CSV, not a pickle.
    df.to_csv(buffer)
    # Rewind so a reader starts at the first byte instead of at EOF.
    buffer.seek(0)
    return buffer
|
|
||||||
|
|
||||||
|
|
||||||
def test_csv_agent_creation(csv: str) -> None:
|
|
||||||
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
|
||||||
assert isinstance(agent, AgentExecutor)
|
|
||||||
|
|
||||||
|
|
||||||
def test_single_csv(csv: str) -> None:
|
|
||||||
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
|
||||||
assert isinstance(agent, AgentExecutor)
|
|
||||||
response = agent.run("How many rows in the csv? Give me a number.")
|
|
||||||
result = re.search(r".*(4).*", response)
|
|
||||||
assert result is not None
|
|
||||||
assert result.group(1) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_csv(csv_list: list) -> None:
|
|
||||||
agent = create_csv_agent(OpenAI(temperature=0), csv_list, verbose=True)
|
|
||||||
assert isinstance(agent, AgentExecutor)
|
|
||||||
response = agent.run("How many combined rows in the two csvs? Give me a number.")
|
|
||||||
result = re.search(r".*(6).*", response)
|
|
||||||
assert result is not None
|
|
||||||
assert result.group(1) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_file_like(csv_file_like: io.BytesIO) -> None:
    """Check that the CSV agent accepts a file-like object and counts its rows."""
    # The parameter name must match the `csv_file_like` fixture; the previous
    # name `file_like` referred to a fixture that does not exist.
    agent = create_csv_agent(OpenAI(temperature=0), csv_file_like, verbose=True)
    assert isinstance(agent, AgentExecutor)
    response = agent.run("How many rows in the csv? Give me a number.")
    result = re.search(r".*(4).*", response)
    assert result is not None
    assert result.group(1) is not None
|
|
@ -1,70 +0,0 @@
|
|||||||
import re
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
from pandas import DataFrame
|
|
||||||
|
|
||||||
from langchain.agents import create_pandas_dataframe_agent
|
|
||||||
from langchain.agents.agent import AgentExecutor
|
|
||||||
from langchain.llms import OpenAI
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def df() -> DataFrame:
|
|
||||||
random_data = np.random.rand(4, 4)
|
|
||||||
df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
# Figure out type hint here
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def df_list() -> list:
|
|
||||||
random_data = np.random.rand(4, 4)
|
|
||||||
df1 = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
|
||||||
random_data = np.random.rand(2, 2)
|
|
||||||
df2 = DataFrame(random_data, columns=["name", "height"])
|
|
||||||
df_list = [df1, df2]
|
|
||||||
return df_list
|
|
||||||
|
|
||||||
|
|
||||||
def test_pandas_agent_creation(df: DataFrame) -> None:
|
|
||||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df)
|
|
||||||
assert isinstance(agent, AgentExecutor)
|
|
||||||
|
|
||||||
|
|
||||||
def test_data_reading(df: DataFrame) -> None:
|
|
||||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df)
|
|
||||||
assert isinstance(agent, AgentExecutor)
|
|
||||||
response = agent.run("how many rows in df? Give me a number.")
|
|
||||||
result = re.search(rf".*({df.shape[0]}).*", response)
|
|
||||||
assert result is not None
|
|
||||||
assert result.group(1) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_data_reading_no_df_in_prompt(df: DataFrame) -> None:
|
|
||||||
agent = create_pandas_dataframe_agent(
|
|
||||||
OpenAI(temperature=0), df, include_df_in_prompt=False
|
|
||||||
)
|
|
||||||
assert isinstance(agent, AgentExecutor)
|
|
||||||
response = agent.run("how many rows in df? Give me a number.")
|
|
||||||
result = re.search(rf".*({df.shape[0]}).*", response)
|
|
||||||
assert result is not None
|
|
||||||
assert result.group(1) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_df(df_list: list) -> None:
|
|
||||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df_list, verbose=True)
|
|
||||||
response = agent.run("how many total rows in the two dataframes? Give me a number.")
|
|
||||||
result = re.search(r".*(6).*", response)
|
|
||||||
assert result is not None
|
|
||||||
assert result.group(1) is not None
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_df_no_df_in_prompt(df_list: list) -> None:
|
|
||||||
agent = create_pandas_dataframe_agent(
|
|
||||||
OpenAI(temperature=0), df_list, include_df_in_prompt=False
|
|
||||||
)
|
|
||||||
response = agent.run("how many total rows in the two dataframes? Give me a number.")
|
|
||||||
result = re.search(r".*(6).*", response)
|
|
||||||
assert result is not None
|
|
||||||
assert result.group(1) is not None
|
|
@ -0,0 +1,23 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from langchain._api import path
|
||||||
|
|
||||||
|
HERE = Path(__file__).parent
|
||||||
|
|
||||||
|
ROOT = HERE.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def test_as_import_path() -> None:
|
||||||
|
"""Test that the path is converted to a LangChain import path."""
|
||||||
|
# Verify that default paths are correct
|
||||||
|
assert path.PACKAGE_DIR == ROOT / "langchain"
|
||||||
|
# Verify that as import path works correctly
|
||||||
|
assert path.as_import_path(HERE, relative_to=ROOT) == "tests.unit_tests._api"
|
||||||
|
assert (
|
||||||
|
path.as_import_path(__file__, relative_to=ROOT)
|
||||||
|
== "tests.unit_tests._api.test_path"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
path.as_import_path(__file__, suffix="create_agent", relative_to=ROOT)
|
||||||
|
== "tests.unit_tests._api.test_path.create_agent"
|
||||||
|
)
|
@ -1,112 +0,0 @@
|
|||||||
"""Test functionality of Python REPL."""
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
|
||||||
from langchain.utilities import PythonREPL
|
|
||||||
|
|
||||||
_SAMPLE_CODE = """
|
|
||||||
```
|
|
||||||
def multiply():
|
|
||||||
print(5*6)
|
|
||||||
multiply()
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
_AST_SAMPLE_CODE = """
|
|
||||||
```
|
|
||||||
def multiply():
|
|
||||||
return(5*6)
|
|
||||||
multiply()
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
_AST_SAMPLE_CODE_EXECUTE = """
|
|
||||||
```
|
|
||||||
def multiply(a, b):
|
|
||||||
return(5*6)
|
|
||||||
a = 5
|
|
||||||
b = 6
|
|
||||||
|
|
||||||
multiply(a, b)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_python_repl() -> None:
|
|
||||||
"""Test functionality when globals/locals are not provided."""
|
|
||||||
repl = PythonREPL()
|
|
||||||
|
|
||||||
# Run a simple initial command.
|
|
||||||
repl.run("foo = 1")
|
|
||||||
assert repl.locals is not None
|
|
||||||
assert repl.locals["foo"] == 1
|
|
||||||
|
|
||||||
# Now run a command that accesses `foo` to make sure it still has it.
|
|
||||||
repl.run("bar = foo * 2")
|
|
||||||
assert repl.locals is not None
|
|
||||||
assert repl.locals["bar"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_python_repl_no_previous_variables() -> None:
|
|
||||||
"""Test that it does not have access to variables created outside the scope."""
|
|
||||||
foo = 3 # noqa: F841
|
|
||||||
repl = PythonREPL()
|
|
||||||
output = repl.run("print(foo)")
|
|
||||||
assert output == """NameError("name 'foo' is not defined")"""
|
|
||||||
|
|
||||||
|
|
||||||
def test_python_repl_pass_in_locals() -> None:
|
|
||||||
"""Test functionality when passing in locals."""
|
|
||||||
_locals = {"foo": 4}
|
|
||||||
repl = PythonREPL(_locals=_locals)
|
|
||||||
repl.run("bar = foo * 2")
|
|
||||||
assert repl.locals is not None
|
|
||||||
assert repl.locals["bar"] == 8
|
|
||||||
|
|
||||||
|
|
||||||
def test_functionality() -> None:
|
|
||||||
"""Test correct functionality."""
|
|
||||||
chain = PythonREPL()
|
|
||||||
code = "print(1 + 1)"
|
|
||||||
output = chain.run(code)
|
|
||||||
assert output == "2\n"
|
|
||||||
|
|
||||||
|
|
||||||
def test_functionality_multiline() -> None:
|
|
||||||
"""Test correct functionality for ChatGPT multiline commands."""
|
|
||||||
chain = PythonREPL()
|
|
||||||
tool = PythonREPLTool(python_repl=chain)
|
|
||||||
output = tool.run(_SAMPLE_CODE)
|
|
||||||
assert output == "30\n"
|
|
||||||
|
|
||||||
|
|
||||||
def test_python_ast_repl_multiline() -> None:
|
|
||||||
"""Test correct functionality for ChatGPT multiline commands."""
|
|
||||||
if sys.version_info < (3, 9):
|
|
||||||
pytest.skip("Python 3.9+ is required for this test")
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
output = tool.run(_AST_SAMPLE_CODE)
|
|
||||||
assert output == 30
|
|
||||||
|
|
||||||
|
|
||||||
def test_python_ast_repl_multi_statement() -> None:
|
|
||||||
"""Test correct functionality for ChatGPT multi statement commands."""
|
|
||||||
if sys.version_info < (3, 9):
|
|
||||||
pytest.skip("Python 3.9+ is required for this test")
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
output = tool.run(_AST_SAMPLE_CODE_EXECUTE)
|
|
||||||
assert output == 30
|
|
||||||
|
|
||||||
|
|
||||||
def test_function() -> None:
|
|
||||||
"""Test correct functionality."""
|
|
||||||
chain = PythonREPL()
|
|
||||||
code = "def add(a, b): " " return a + b"
|
|
||||||
output = chain.run(code)
|
|
||||||
assert output == ""
|
|
||||||
|
|
||||||
code = "print(add(1, 2))"
|
|
||||||
output = chain.run(code)
|
|
||||||
assert output == "3\n"
|
|
@ -1,164 +0,0 @@
|
|||||||
"""Test Python REPL Tools."""
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from langchain.tools.python.tool import (
|
|
||||||
PythonAstREPLTool,
|
|
||||||
PythonREPLTool,
|
|
||||||
sanitize_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_python_repl_tool_single_input() -> None:
|
|
||||||
"""Test that the python REPL tool works with a single input."""
|
|
||||||
tool = PythonREPLTool()
|
|
||||||
assert tool.is_single_input
|
|
||||||
assert int(tool.run("print(1 + 1)").strip()) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_python_repl_print() -> None:
|
|
||||||
program = """
|
|
||||||
import numpy as np
|
|
||||||
v1 = np.array([1, 2, 3])
|
|
||||||
v2 = np.array([4, 5, 6])
|
|
||||||
dot_product = np.dot(v1, v2)
|
|
||||||
print("The dot product is {:d}.".format(dot_product))
|
|
||||||
"""
|
|
||||||
tool = PythonREPLTool()
|
|
||||||
assert tool.run(program) == "The dot product is 32.\n"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_python_ast_repl_tool_single_input() -> None:
|
|
||||||
"""Test that the python REPL tool works with a single input."""
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
assert tool.is_single_input
|
|
||||||
assert tool.run("1 + 1") == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_python_ast_repl_return() -> None:
|
|
||||||
program = """
|
|
||||||
```
|
|
||||||
import numpy as np
|
|
||||||
v1 = np.array([1, 2, 3])
|
|
||||||
v2 = np.array([4, 5, 6])
|
|
||||||
dot_product = np.dot(v1, v2)
|
|
||||||
int(dot_product)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
assert tool.run(program) == 32
|
|
||||||
|
|
||||||
program = """
|
|
||||||
```python
|
|
||||||
import numpy as np
|
|
||||||
v1 = np.array([1, 2, 3])
|
|
||||||
v2 = np.array([4, 5, 6])
|
|
||||||
dot_product = np.dot(v1, v2)
|
|
||||||
int(dot_product)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
assert tool.run(program) == 32
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_python_ast_repl_print() -> None:
|
|
||||||
program = """python
|
|
||||||
string = "racecar"
|
|
||||||
if string == string[::-1]:
|
|
||||||
print(string, "is a palindrome")
|
|
||||||
else:
|
|
||||||
print(string, "is not a palindrome")"""
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
assert tool.run(program) == "racecar is a palindrome\n"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_repl_print_python_backticks() -> None:
|
|
||||||
program = "`print('`python` is a great language.')`"
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
assert tool.run(program) == "`python` is a great language.\n"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_python_ast_repl_raise_exception() -> None:
|
|
||||||
data = {"Name": ["John", "Alice"], "Age": [30, 25]}
|
|
||||||
program = """
|
|
||||||
import pandas as pd
|
|
||||||
df = pd.DataFrame(data)
|
|
||||||
df['Gender']
|
|
||||||
"""
|
|
||||||
tool = PythonAstREPLTool(locals={"data": data})
|
|
||||||
expected_outputs = (
|
|
||||||
"KeyError: 'Gender'",
|
|
||||||
"ModuleNotFoundError: No module named 'pandas'",
|
|
||||||
)
|
|
||||||
assert tool.run(program) in expected_outputs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_python_ast_repl_one_line_print() -> None:
|
|
||||||
program = 'print("The square of {} is {:.2f}".format(3, 3**2))'
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
assert tool.run(program) == "The square of 3 is 9.00\n"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_python_ast_repl_one_line_return() -> None:
|
|
||||||
arr = np.array([1, 2, 3, 4, 5])
|
|
||||||
tool = PythonAstREPLTool(locals={"arr": arr})
|
|
||||||
program = "`(arr**2).sum() # Returns sum of squares`"
|
|
||||||
assert tool.run(program) == 55
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
|
||||||
)
|
|
||||||
def test_python_ast_repl_one_line_exception() -> None:
|
|
||||||
program = "[1, 2, 3][4]"
|
|
||||||
tool = PythonAstREPLTool()
|
|
||||||
assert tool.run(program) == "IndexError: list index out of range"
|
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_input() -> None:
|
|
||||||
query = """
|
|
||||||
```
|
|
||||||
p = 5
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
expected = "p = 5"
|
|
||||||
actual = sanitize_input(query)
|
|
||||||
assert expected == actual
|
|
||||||
|
|
||||||
query = """
|
|
||||||
```python
|
|
||||||
p = 5
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
expected = "p = 5"
|
|
||||||
actual = sanitize_input(query)
|
|
||||||
assert expected == actual
|
|
||||||
|
|
||||||
query = """
|
|
||||||
p = 5
|
|
||||||
"""
|
|
||||||
expected = "p = 5"
|
|
||||||
actual = sanitize_input(query)
|
|
||||||
assert expected == actual
|
|
Loading…
Reference in New Issue