diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index 62c8160c52..0caaed131f 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -37,7 +37,9 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC): def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue: return self._call_with_config( - lambda inner_input: self.format_prompt(**inner_input), + lambda inner_input: self.format_prompt( + **{key: inner_input[key] for key in self.input_variables} + ), input, config, run_type="prompt", diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 45fbe1ca40..61f74eb22b 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -369,6 +369,24 @@ async def test_prompt() -> None: ] == [expected] +def test_prompt_template_params() -> None: + prompt = ChatPromptTemplate.from_template( + "Respond to the following question: {question}" + ) + result = prompt.invoke( + { + "question": "test", + "topic": "test", + } + ) + assert result == ChatPromptValue( + messages=[HumanMessage(content="Respond to the following question: test")] + ) + + with pytest.raises(KeyError): + prompt.invoke({}) + + @pytest.mark.asyncio @freeze_time("2023-01-01") async def test_prompt_with_chat_model(