mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
openai[patch]: accept function_call dict in bind_functions (#16483)
Confusing that you can't pass in a dict
This commit is contained in:
parent
db80832e4f
commit
31790d15ec
@ -18,6 +18,7 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -182,6 +183,10 @@ def _convert_delta_to_message_chunk(
|
||||
return default_class(content=content) # type: ignore
|
||||
|
||||
|
||||
class _FunctionCall(TypedDict):
|
||||
name: str
|
||||
|
||||
|
||||
class ChatOpenAI(BaseChatModel):
|
||||
"""`OpenAI` Chat large language models API.
|
||||
|
||||
@ -632,7 +637,9 @@ class ChatOpenAI(BaseChatModel):
|
||||
def bind_functions(
|
||||
self,
|
||||
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
function_call: Optional[str] = None,
|
||||
function_call: Optional[
|
||||
Union[_FunctionCall, str, Literal["auto", "none"]]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind functions (and other objects) to this chat model.
|
||||
@ -658,18 +665,26 @@ class ChatOpenAI(BaseChatModel):
|
||||
|
||||
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
||||
if function_call is not None:
|
||||
if len(formatted_functions) != 1:
|
||||
function_call = (
|
||||
{"name": function_call}
|
||||
if isinstance(function_call, str)
|
||||
and function_call not in ("auto", "none")
|
||||
else function_call
|
||||
)
|
||||
if isinstance(function_call, dict) and len(formatted_functions) != 1:
|
||||
raise ValueError(
|
||||
"When specifying `function_call`, you must provide exactly one "
|
||||
"function."
|
||||
)
|
||||
if formatted_functions[0]["name"] != function_call:
|
||||
if (
|
||||
isinstance(function_call, dict)
|
||||
and formatted_functions[0]["name"] != function_call["name"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Function call {function_call} was specified, but the only "
|
||||
f"provided function was {formatted_functions[0]['name']}."
|
||||
)
|
||||
function_call_ = {"name": function_call}
|
||||
kwargs = {**kwargs, "function_call": function_call_}
|
||||
kwargs = {**kwargs, "function_call": function_call}
|
||||
return super().bind(
|
||||
functions=formatted_functions,
|
||||
**kwargs,
|
||||
|
Loading…
Reference in New Issue
Block a user