|
|
|
@ -280,7 +280,11 @@ def _get_prompt(inputs: Dict[str, Any]) -> str:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
|
|
|
|
|
class ChatModelInput(TypedDict):
|
|
|
|
|
messages: List[BaseMessage]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_messages(inputs: Dict[str, Any]) -> dict:
|
|
|
|
|
"""Get Chat Messages from inputs.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -293,35 +297,29 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
|
|
|
|
|
"""
|
|
|
|
|
if not inputs:
|
|
|
|
|
raise InputFormatError("Inputs should not be empty.")
|
|
|
|
|
|
|
|
|
|
input_copy = inputs.copy()
|
|
|
|
|
if "messages" in inputs:
|
|
|
|
|
single_input = inputs["messages"]
|
|
|
|
|
input_copy["input"] = input_copy.pop("messages")
|
|
|
|
|
elif len(inputs) == 1:
|
|
|
|
|
single_input = next(iter(inputs.values()))
|
|
|
|
|
else:
|
|
|
|
|
raise InputFormatError(
|
|
|
|
|
f"Chat Run expects 'messages' in inputs when example has multiple"
|
|
|
|
|
f" input keys. Got {inputs}"
|
|
|
|
|
)
|
|
|
|
|
if isinstance(single_input, list) and all(
|
|
|
|
|
isinstance(i, dict) for i in single_input
|
|
|
|
|
):
|
|
|
|
|
raw_messages = [single_input]
|
|
|
|
|
elif isinstance(single_input, list) and all(
|
|
|
|
|
isinstance(i, list) for i in single_input
|
|
|
|
|
):
|
|
|
|
|
raw_messages = single_input
|
|
|
|
|
else:
|
|
|
|
|
raise InputFormatError(
|
|
|
|
|
f"Chat Run expects List[dict] or List[List[dict]] values for"
|
|
|
|
|
f" 'messages' key input. Got {inputs}"
|
|
|
|
|
)
|
|
|
|
|
if len(raw_messages) == 1:
|
|
|
|
|
return messages_from_dict(raw_messages[0])
|
|
|
|
|
input_copy["input"] = next(iter(inputs.values()))
|
|
|
|
|
if "input" in input_copy:
|
|
|
|
|
raw_messages = input_copy["input"]
|
|
|
|
|
if isinstance(raw_messages, list) and all(
|
|
|
|
|
isinstance(i, dict) for i in raw_messages
|
|
|
|
|
):
|
|
|
|
|
raw_messages = [raw_messages]
|
|
|
|
|
if len(raw_messages) == 1:
|
|
|
|
|
input_copy["input"] = messages_from_dict(raw_messages[0])
|
|
|
|
|
else:
|
|
|
|
|
raise InputFormatError(
|
|
|
|
|
"Batch messages not supported. Please provide a"
|
|
|
|
|
" single list of messages."
|
|
|
|
|
)
|
|
|
|
|
return input_copy
|
|
|
|
|
else:
|
|
|
|
|
raise InputFormatError(
|
|
|
|
|
f"Chat Run expects single List[dict] or List[List[dict]] 'messages'"
|
|
|
|
|
f" input. Got {len(raw_messages)} messages from inputs {inputs}"
|
|
|
|
|
f" input. Got {inputs}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -711,9 +709,9 @@ async def _arun_llm(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
except InputFormatError:
|
|
|
|
|
messages = _get_messages(inputs)
|
|
|
|
|
llm_inputs = _get_messages(inputs)
|
|
|
|
|
llm_output = await llm.ainvoke(
|
|
|
|
|
messages,
|
|
|
|
|
**llm_inputs,
|
|
|
|
|
config=RunnableConfig(
|
|
|
|
|
callbacks=callbacks, tags=tags or [], metadata=metadata or {}
|
|
|
|
|
),
|
|
|
|
@ -864,9 +862,9 @@ def _run_llm(
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
except InputFormatError:
|
|
|
|
|
llm_messages = _get_messages(inputs)
|
|
|
|
|
llm_inputs = _get_messages(inputs)
|
|
|
|
|
llm_output = llm.invoke(
|
|
|
|
|
llm_messages,
|
|
|
|
|
**llm_inputs,
|
|
|
|
|
config=RunnableConfig(callbacks=callbacks, metadata=metadata or {}),
|
|
|
|
|
)
|
|
|
|
|
return llm_output
|
|
|
|
|