[Evals] Fix function calling support (#19658)

Current implementation is overzealous in validating chat datasets

Fixes
[#langsmith-sdk:557](https://github.com/langchain-ai/langsmith-sdk/issues/557)
pull/19686/head
William FH 6 months ago committed by GitHub
parent 7e29b6061f
commit 5c41f4083e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -280,7 +280,11 @@ def _get_prompt(inputs: Dict[str, Any]) -> str:
) )
class ChatModelInput(TypedDict):
    """Input payload accepted by a chat model during evaluation."""

    # Conversation history forwarded to the chat model.
    messages: List[BaseMessage]
def _get_messages(inputs: Dict[str, Any]) -> dict:
    """Get Chat Messages from inputs.

    Normalizes the example inputs into a kwargs dict for ``llm.invoke``:
    the chat messages end up under the ``"input"`` key, while any other
    keys (e.g. function/tool definitions) are passed through untouched —
    this is what allows function-calling datasets to validate.

    Args:
        inputs: The input dictionary from the example.

    Returns:
        A copy of ``inputs`` whose ``"input"`` key holds the parsed
        ``List[BaseMessage]``; all other keys are preserved.

    Raises:
        InputFormatError: If ``inputs`` is empty, contains batched
            message lists, or has no recognizable messages key.
    """
    if not inputs:
        raise InputFormatError("Inputs should not be empty.")
    input_copy = inputs.copy()
    # Accept either an explicit "messages" key or a single-key dict.
    if "messages" in inputs:
        input_copy["input"] = input_copy.pop("messages")
    elif len(inputs) == 1:
        input_copy["input"] = next(iter(inputs.values()))
    if "input" in input_copy:
        raw_messages = input_copy["input"]
        # A flat List[dict] is a single conversation; wrap it so the
        # single-vs-batch check below is uniform.
        if isinstance(raw_messages, list) and all(
            isinstance(i, dict) for i in raw_messages
        ):
            raw_messages = [raw_messages]
        if len(raw_messages) == 1:
            input_copy["input"] = messages_from_dict(raw_messages[0])
        else:
            raise InputFormatError(
                "Batch messages not supported. Please provide a"
                " single list of messages."
            )
        return input_copy
    else:
        raise InputFormatError(
            f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
            f" input. Got {inputs}"
        )
@ -711,9 +709,9 @@ async def _arun_llm(
), ),
) )
except InputFormatError: except InputFormatError:
messages = _get_messages(inputs) llm_inputs = _get_messages(inputs)
llm_output = await llm.ainvoke( llm_output = await llm.ainvoke(
messages, **llm_inputs,
config=RunnableConfig( config=RunnableConfig(
callbacks=callbacks, tags=tags or [], metadata=metadata or {} callbacks=callbacks, tags=tags or [], metadata=metadata or {}
), ),
@ -864,9 +862,9 @@ def _run_llm(
), ),
) )
except InputFormatError: except InputFormatError:
llm_messages = _get_messages(inputs) llm_inputs = _get_messages(inputs)
llm_output = llm.invoke( llm_output = llm.invoke(
llm_messages, **llm_inputs,
config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}), config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}),
) )
return llm_output return llm_output

Loading…
Cancel
Save