support functions (#6099)

This commit is contained in:
Harrison Chase 2023-06-13 10:32:58 -07:00 committed by GitHub
parent ee3d0513ad
commit 292accde2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -92,12 +92,17 @@ async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
return await _completion_with_retry(**kwargs) return await _completion_with_retry(**kwargs)
def _convert_dict_to_message(_dict: dict) -> BaseMessage: def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"] role = _dict["role"]
if role == "user": if role == "user":
return HumanMessage(content=_dict["content"]) return HumanMessage(content=_dict["content"])
elif role == "assistant": elif role == "assistant":
return AIMessage(content=_dict["content"]) content = _dict["content"] or "" # OpenAI returns None for tool invocations
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
return AIMessage(content=content, additional_kwargs=additional_kwargs)
elif role == "system": elif role == "system":
return SystemMessage(content=_dict["content"]) return SystemMessage(content=_dict["content"])
else: else:
@ -111,6 +116,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": "user", "content": message.content} message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage): elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content} message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage): elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
else: else: