From 31790d15ec2cc2d9857eba2945b7baa97ac97494 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Thu, 25 Jan 2024 13:47:44 -0800 Subject: [PATCH] openai[patch]: accept function_call dict in bind_functions (#16483) Confusing that you can't pass in a dict --- .../langchain_openai/chat_models/base.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 5989c3ccfa..3c5bfe35a3 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -18,6 +18,7 @@ from typing import ( Sequence, Tuple, Type, + TypedDict, Union, cast, ) @@ -182,6 +183,10 @@ def _convert_delta_to_message_chunk( return default_class(content=content) # type: ignore +class _FunctionCall(TypedDict): + name: str + + class ChatOpenAI(BaseChatModel): """`OpenAI` Chat large language models API. @@ -632,7 +637,9 @@ class ChatOpenAI(BaseChatModel): def bind_functions( self, functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], - function_call: Optional[str] = None, + function_call: Optional[ + Union[_FunctionCall, str, Literal["auto", "none"]] + ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind functions (and other objects) to this chat model. @@ -658,18 +665,26 @@ class ChatOpenAI(BaseChatModel): formatted_functions = [convert_to_openai_function(fn) for fn in functions] if function_call is not None: - if len(formatted_functions) != 1: + function_call = ( + {"name": function_call} + if isinstance(function_call, str) + and function_call not in ("auto", "none") + else function_call + ) + if isinstance(function_call, dict) and len(formatted_functions) != 1: raise ValueError( "When specifying `function_call`, you must provide exactly one " "function." ) - if formatted_functions[0]["name"] != function_call: + if ( + isinstance(function_call, dict) + and formatted_functions[0]["name"] != function_call["name"] + ): raise ValueError( f"Function call {function_call} was specified, but the only " f"provided function was {formatted_functions[0]['name']}." ) - function_call_ = {"name": function_call} - kwargs = {**kwargs, "function_call": function_call_} + kwargs = {**kwargs, "function_call": function_call} return super().bind( functions=formatted_functions, **kwargs,