diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index dc5003e552..21b31650ff 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -67,13 +67,27 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): **{k: (self.input_types.get(k, str), None) for k in self.input_variables}, ) + def _format_prompt_with_error_handling(self, inner_input: Dict) -> PromptValue: + try: + input_dict = {key: inner_input[key] for key in self.input_variables} + except TypeError as e: + raise TypeError( + f"Expected mapping type as input to {self.__class__.__name__}. " + f"Received {type(inner_input)}." + ) from e + except KeyError as e: + raise KeyError( + f"Input to {self.__class__.__name__} is missing variable {e}. " + f" Expected: {self.input_variables}" + f" Received: {list(inner_input.keys())}" + ) from e + return self.format_prompt(**input_dict) + def invoke( self, input: Dict, config: Optional[RunnableConfig] = None ) -> PromptValue: return self._call_with_config( - lambda inner_input: self.format_prompt( - **{key: inner_input[key] for key in self.input_variables} - ), + self._format_prompt_with_error_handling, input, config, run_type="prompt",