openai[patch]: accept function_call dict in bind_functions (#16483)

Confusing that you can't pass in a dict
pull/16590/head
Bagatur 8 months ago committed by GitHub
parent db80832e4f
commit 31790d15ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -18,6 +18,7 @@ from typing import (
Sequence, Sequence,
Tuple, Tuple,
Type, Type,
TypedDict,
Union, Union,
cast, cast,
) )
@ -182,6 +183,10 @@ def _convert_delta_to_message_chunk(
return default_class(content=content) # type: ignore return default_class(content=content) # type: ignore
class _FunctionCall(TypedDict):
name: str
class ChatOpenAI(BaseChatModel): class ChatOpenAI(BaseChatModel):
"""`OpenAI` Chat large language models API. """`OpenAI` Chat large language models API.
@ -632,7 +637,9 @@ class ChatOpenAI(BaseChatModel):
def bind_functions( def bind_functions(
self, self,
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
function_call: Optional[str] = None, function_call: Optional[
Union[_FunctionCall, str, Literal["auto", "none"]]
] = None,
**kwargs: Any, **kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]: ) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind functions (and other objects) to this chat model. """Bind functions (and other objects) to this chat model.
@ -658,18 +665,26 @@ class ChatOpenAI(BaseChatModel):
formatted_functions = [convert_to_openai_function(fn) for fn in functions] formatted_functions = [convert_to_openai_function(fn) for fn in functions]
if function_call is not None: if function_call is not None:
if len(formatted_functions) != 1: function_call = (
{"name": function_call}
if isinstance(function_call, str)
and function_call not in ("auto", "none")
else function_call
)
if isinstance(function_call, dict) and len(formatted_functions) != 1:
raise ValueError( raise ValueError(
"When specifying `function_call`, you must provide exactly one " "When specifying `function_call`, you must provide exactly one "
"function." "function."
) )
if formatted_functions[0]["name"] != function_call: if (
isinstance(function_call, dict)
and formatted_functions[0]["name"] != function_call["name"]
):
raise ValueError( raise ValueError(
f"Function call {function_call} was specified, but the only " f"Function call {function_call} was specified, but the only "
f"provided function was {formatted_functions[0]['name']}." f"provided function was {formatted_functions[0]['name']}."
) )
function_call_ = {"name": function_call} kwargs = {**kwargs, "function_call": function_call}
kwargs = {**kwargs, "function_call": function_call_}
return super().bind( return super().bind(
functions=formatted_functions, functions=formatted_functions,
**kwargs, **kwargs,

Loading…
Cancel
Save