diff --git a/langchain/tools/zapier/tool.py b/langchain/tools/zapier/tool.py index f68a3562f7..cb1fc29544 100644 --- a/langchain/tools/zapier/tool.py +++ b/langchain/tools/zapier/tool.py @@ -105,6 +105,7 @@ class ZapierNLARunAction(BaseTool): api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper) action_id: str params: Optional[dict] = None + base_prompt: str = BASE_ZAPIER_TOOL_PROMPT zapier_description: str params_schema: Dict[str, str] = Field(default_factory=dict) name = "" @@ -116,8 +117,17 @@ class ZapierNLARunAction(BaseTool): params_schema = values["params_schema"] if "instructions" in params_schema: del params_schema["instructions"] + + # Ensure base prompt (if overrided) contains necessary input fields + necessary_fields = {"{zapier_description}", "{params}"} + if not all(field in values["base_prompt"] for field in necessary_fields): + raise ValueError( + "Your custom base Zapier prompt must contain input fields for " + "{zapier_description} and {params}." + ) + values["name"] = zapier_description - values["description"] = BASE_ZAPIER_TOOL_PROMPT.format( + values["description"] = values["base_prompt"].format( zapier_description=zapier_description, params=str(list(params_schema.keys())), ) diff --git a/tests/unit_tests/tools/test_zapier.py b/tests/unit_tests/tools/test_zapier.py new file mode 100644 index 0000000000..a4b60be965 --- /dev/null +++ b/tests/unit_tests/tools/test_zapier.py @@ -0,0 +1,52 @@ +"""Test building the Zapier tool, not running it.""" +import pytest + +from langchain.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT +from langchain.tools.zapier.tool import ZapierNLARunAction +from langchain.utilities.zapier import ZapierNLAWrapper + + +def test_default_base_prompt() -> None: + """Test that the default prompt is being inserted.""" + tool = ZapierNLARunAction( + action_id="test", + zapier_description="test", + params_schema={"test": "test"}, + api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"), + ) + + # Test that the base prompt was successfully assigned to the default prompt + assert tool.base_prompt == BASE_ZAPIER_TOOL_PROMPT + assert tool.description == BASE_ZAPIER_TOOL_PROMPT.format( + zapier_description="test", + params=str(list({"test": "test"}.keys())), + ) + + +def test_custom_base_prompt() -> None: + """Test that a custom prompt is being inserted.""" + base_prompt = "Test. {zapier_description} and {params}." + tool = ZapierNLARunAction( + action_id="test", + zapier_description="test", + params_schema={"test": "test"}, + base_prompt=base_prompt, + api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"), + ) + + # Test that the base prompt was successfully assigned to the default prompt + assert tool.base_prompt == base_prompt + assert tool.description == "Test. test and ['test']." + + +def test_custom_base_prompt_fail() -> None: + """Test validating an invalid custom prompt.""" + base_prompt = "Test. {zapier_description}." + with pytest.raises(ValueError): + ZapierNLARunAction( + action_id="test", + zapier_description="test", + params={"test": "test"}, + base_prompt=base_prompt, + api_wrapper=ZapierNLAWrapper(zapier_nla_api_key="test"), + )