From 2747ccbcf16305a34f8c5e97f44f3f5bef108b5f Mon Sep 17 00:00:00 2001 From: Prerit Das Date: Sun, 14 May 2023 00:08:18 -0400 Subject: [PATCH] Allow custom base Zapier prompt (#4213) Currently, all Zapier tools are built using the pre-written base Zapier prompt. These small changes (that retain default behavior) will allow a user to create a Zapier tool using the ZapierNLARunTool while providing their own base prompt. Their prompt must contain input fields for zapier_description and params, checked and enforced in the tool's root validator. An example of when this may be useful: user has several, say 10, Zapier tools enabled. Currently, the long generic default Zapier base prompt is attached to every single tool, using an extreme number of tokens for no real added benefit (repeated). User prompts LLM on how to use Zapier tools once, then overrides the base prompt. Or: user has a few specific Zapier tools and wants to maximize their success rate. So, user writes prompts/descriptions for those tools specific to their use case, and provides those to the ZapierNLARunTool. A consideration - this is the simplest way to implement this I could think of... though ideally custom prompting would be possible at the Toolkit level as well. For now, this should be sufficient in solving the concerns outlined above. --- langchain/tools/zapier/tool.py | 12 ++++++- tests/unit_tests/tools/test_zapier.py | 52 +++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/tools/test_zapier.py diff --git a/langchain/tools/zapier/tool.py b/langchain/tools/zapier/tool.py index f68a3562..cb1fc295 100644 --- a/langchain/tools/zapier/tool.py +++ b/langchain/tools/zapier/tool.py @@ -105,6 +105,7 @@ class ZapierNLARunAction(BaseTool): api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper) action_id: str params: Optional[dict] = None + base_prompt: str = BASE_ZAPIER_TOOL_PROMPT zapier_description: str params_schema: Dict[str, str] = Field(default_factory=dict) name = "" @@ -116,8 +117,17 @@ class ZapierNLARunAction(BaseTool): params_schema = values["params_schema"] if "instructions" in params_schema: del params_schema["instructions"] + + # Ensure base prompt (if overrided) contains necessary input fields + necessary_fields = {"{zapier_description}", "{params}"} + if not all(field in values["base_prompt"] for field in necessary_fields): + raise ValueError( + "Your custom base Zapier prompt must contain input fields for " + "{zapier_description} and {params}." + ) + values["name"] = zapier_description - values["description"] = BASE_ZAPIER_TOOL_PROMPT.format( + values["description"] = values["base_prompt"].format( zapier_description=zapier_description, params=str(list(params_schema.keys())), ) diff --git a/tests/unit_tests/tools/test_zapier.py b/tests/unit_tests/tools/test_zapier.py new file mode 100644 index 00000000..a4b60be9 --- /dev/null +++ b/tests/unit_tests/tools/test_zapier.py @@ -0,0 +1,52 @@ +"""Test building the Zapier tool, not running it.""" +import pytest + +from langchain.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT +from langchain.tools.zapier.tool import ZapierNLARunAction +from langchain.utilities.zapier import ZapierNLAWrapper + + +def test_default_base_prompt() -> None: + """Test that the default prompt is being inserted.""" + tool = ZapierNLARunAction( + action_id="test", + zapier_description="test", + params_schema={"test": "test"}, + api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"), + ) + + # Test that the base prompt was successfully assigned to the default prompt + assert tool.base_prompt == BASE_ZAPIER_TOOL_PROMPT + assert tool.description == BASE_ZAPIER_TOOL_PROMPT.format( + zapier_description="test", + params=str(list({"test": "test"}.keys())), + ) + + +def test_custom_base_prompt() -> None: + """Test that a custom prompt is being inserted.""" + base_prompt = "Test. {zapier_description} and {params}." + tool = ZapierNLARunAction( + action_id="test", + zapier_description="test", + params_schema={"test": "test"}, + base_prompt=base_prompt, + api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"), + ) + + # Test that the base prompt was successfully assigned to the default prompt + assert tool.base_prompt == base_prompt + assert tool.description == "Test. test and ['test']." + + +def test_custom_base_prompt_fail() -> None: + """Test validating an invalid custom prompt.""" + base_prompt = "Test. {zapier_description}." + with pytest.raises(ValueError): + ZapierNLARunAction( + action_id="test", + zapier_description="test", + params={"test": "test"}, + base_prompt=base_prompt, + api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"), + )