Allow custom base Zapier prompt (#4213)

Currently, all Zapier tools are built using the pre-written base Zapier
prompt. These small changes (that retain default behavior) will allow a
user to create a Zapier tool using the ZapierNLARunTool while providing
their own base prompt.

Their prompt must contain input fields for zapier_description and
params, checked and enforced in the tool's root validator.

An example of when this may be useful: user has several, say 10, Zapier
tools enabled. Currently, the long generic default Zapier base prompt is
attached to every single tool, using an extreme number of tokens for no
real added benefit (repeated). User prompts LLM on how to use Zapier
tools once, then overrides the base prompt.

Or: user has a few specific Zapier tools and wants to maximize their
success rate. So, user writes prompts/descriptions for those tools
specific to their use case, and provides those to the ZapierNLARunTool.

A consideration - this is the simplest way to implement this I could
think of... though ideally custom prompting would be possible at the
Toolkit level as well. For now, this should be sufficient in solving the
concerns outlined above.
This commit is contained in:
Prerit Das 2023-05-14 00:08:18 -04:00 committed by GitHub
parent e2bc836571
commit 2747ccbcf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 1 deletions

View File

@ -105,6 +105,7 @@ class ZapierNLARunAction(BaseTool):
api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper) api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper)
action_id: str action_id: str
params: Optional[dict] = None params: Optional[dict] = None
base_prompt: str = BASE_ZAPIER_TOOL_PROMPT
zapier_description: str zapier_description: str
params_schema: Dict[str, str] = Field(default_factory=dict) params_schema: Dict[str, str] = Field(default_factory=dict)
name = "" name = ""
@ -116,8 +117,17 @@ class ZapierNLARunAction(BaseTool):
params_schema = values["params_schema"] params_schema = values["params_schema"]
if "instructions" in params_schema: if "instructions" in params_schema:
del params_schema["instructions"] del params_schema["instructions"]
# Ensure base prompt (if overrided) contains necessary input fields
necessary_fields = {"{zapier_description}", "{params}"}
if not all(field in values["base_prompt"] for field in necessary_fields):
raise ValueError(
"Your custom base Zapier prompt must contain input fields for "
"{zapier_description} and {params}."
)
values["name"] = zapier_description values["name"] = zapier_description
values["description"] = BASE_ZAPIER_TOOL_PROMPT.format( values["description"] = values["base_prompt"].format(
zapier_description=zapier_description, zapier_description=zapier_description,
params=str(list(params_schema.keys())), params=str(list(params_schema.keys())),
) )

View File

@ -0,0 +1,52 @@
"""Test building the Zapier tool, not running it."""
import pytest
from langchain.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT
from langchain.tools.zapier.tool import ZapierNLARunAction
from langchain.utilities.zapier import ZapierNLAWrapper
def test_default_base_prompt() -> None:
"""Test that the default prompt is being inserted."""
tool = ZapierNLARunAction(
action_id="test",
zapier_description="test",
params_schema={"test": "test"},
api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"),
)
# Test that the base prompt was successfully assigned to the default prompt
assert tool.base_prompt == BASE_ZAPIER_TOOL_PROMPT
assert tool.description == BASE_ZAPIER_TOOL_PROMPT.format(
zapier_description="test",
params=str(list({"test": "test"}.keys())),
)
def test_custom_base_prompt() -> None:
"""Test that a custom prompt is being inserted."""
base_prompt = "Test. {zapier_description} and {params}."
tool = ZapierNLARunAction(
action_id="test",
zapier_description="test",
params_schema={"test": "test"},
base_prompt=base_prompt,
api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"),
)
# Test that the base prompt was successfully assigned to the default prompt
assert tool.base_prompt == base_prompt
assert tool.description == "Test. test and ['test']."
def test_custom_base_prompt_fail() -> None:
"""Test validating an invalid custom prompt."""
base_prompt = "Test. {zapier_description}."
with pytest.raises(ValueError):
ZapierNLARunAction(
action_id="test",
zapier_description="test",
params={"test": "test"},
base_prompt=base_prompt,
api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"),
)