diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 35c7b8a7e8..2559c8eb71 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -280,7 +280,11 @@ def _get_prompt(inputs: Dict[str, Any]) -> str: ) -def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]: +class ChatModelInput(TypedDict): + messages: List[BaseMessage] + + +def _get_messages(inputs: Dict[str, Any]) -> dict: """Get Chat Messages from inputs. Args: @@ -293,35 +297,29 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]: """ if not inputs: raise InputFormatError("Inputs should not be empty.") - + input_copy = inputs.copy() if "messages" in inputs: - single_input = inputs["messages"] + input_copy["input"] = input_copy.pop("messages") elif len(inputs) == 1: - single_input = next(iter(inputs.values())) - else: - raise InputFormatError( - f"Chat Run expects 'messages' in inputs when example has multiple" - f" input keys. Got {inputs}" - ) - if isinstance(single_input, list) and all( - isinstance(i, dict) for i in single_input - ): - raw_messages = [single_input] - elif isinstance(single_input, list) and all( - isinstance(i, list) for i in single_input - ): - raw_messages = single_input - else: - raise InputFormatError( - f"Chat Run expects List[dict] or List[List[dict]] values for" - f" 'messages' key input. Got {inputs}" - ) - if len(raw_messages) == 1: - return messages_from_dict(raw_messages[0]) + input_copy["input"] = next(iter(inputs.values())) + if "input" in input_copy: + raw_messages = input_copy["input"] + if isinstance(raw_messages, list) and all( + isinstance(i, dict) for i in raw_messages + ): + raw_messages = [raw_messages] + if len(raw_messages) == 1: + input_copy["input"] = messages_from_dict(raw_messages[0]) + else: + raise InputFormatError( + "Batch messages not supported. Please provide a" + " single list of messages." + ) + return input_copy else: raise InputFormatError( f"Chat Run expects single List[dict] or List[List[dict]] 'messages'" - f" input. Got {len(raw_messages)} messages from inputs {inputs}" + f" input. Got {inputs}" ) @@ -711,9 +709,9 @@ async def _arun_llm( ), ) except InputFormatError: - messages = _get_messages(inputs) + llm_inputs = _get_messages(inputs) llm_output = await llm.ainvoke( - messages, + **llm_inputs, config=RunnableConfig( callbacks=callbacks, tags=tags or [], metadata=metadata or {} ), @@ -864,9 +862,9 @@ def _run_llm( ), ) except InputFormatError: - llm_messages = _get_messages(inputs) + llm_inputs = _get_messages(inputs) llm_output = llm.invoke( - llm_messages, + **llm_inputs, config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}), ) return llm_output