openai[patch]: accept function_call dict in bind_functions (#16483)

Confusing that you can't pass in a dict
This commit is contained in:
Bagatur 2024-01-25 13:47:44 -08:00 committed by GitHub
parent db80832e4f
commit 31790d15ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,6 +18,7 @@ from typing import (
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
)
@ -182,6 +183,10 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) # type: ignore
class _FunctionCall(TypedDict):
name: str
class ChatOpenAI(BaseChatModel):
"""`OpenAI` Chat large language models API.
@ -632,7 +637,9 @@ class ChatOpenAI(BaseChatModel):
def bind_functions(
self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
function_call: Optional[str] = None,
function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind functions (and other objects) to this chat model.
@ -658,18 +665,26 @@ class ChatOpenAI(BaseChatModel):
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
if function_call is not None:
if len(formatted_functions) != 1:
function_call = (
{"name": function_call}
if isinstance(function_call, str)
and function_call not in ("auto", "none")
else function_call
)
if isinstance(function_call, dict) and len(formatted_functions) != 1:
raise ValueError(
"When specifying `function_call`, you must provide exactly one "
"function."
)
if formatted_functions[0]["name"] != function_call:
if (
isinstance(function_call, dict)
and formatted_functions[0]["name"] != function_call["name"]
):
raise ValueError(
f"Function call {function_call} was specified, but the only "
f"provided function was {formatted_functions[0]['name']}."
)
function_call_ = {"name": function_call}
kwargs = {**kwargs, "function_call": function_call_}
kwargs = {**kwargs, "function_call": function_call}
return super().bind(
functions=formatted_functions,
**kwargs,