|
|
|
@ -18,6 +18,7 @@ from typing import (
|
|
|
|
|
Sequence,
|
|
|
|
|
Tuple,
|
|
|
|
|
Type,
|
|
|
|
|
TypedDict,
|
|
|
|
|
Union,
|
|
|
|
|
cast,
|
|
|
|
|
)
|
|
|
|
@ -182,6 +183,10 @@ def _convert_delta_to_message_chunk(
|
|
|
|
|
return default_class(content=content) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _FunctionCall(TypedDict):
|
|
|
|
|
name: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatOpenAI(BaseChatModel):
|
|
|
|
|
"""`OpenAI` Chat large language models API.
|
|
|
|
|
|
|
|
|
@ -632,7 +637,9 @@ class ChatOpenAI(BaseChatModel):
|
|
|
|
|
def bind_functions(
|
|
|
|
|
self,
|
|
|
|
|
functions: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
|
|
|
|
function_call: Optional[str] = None,
|
|
|
|
|
function_call: Optional[
|
|
|
|
|
Union[_FunctionCall, str, Literal["auto", "none"]]
|
|
|
|
|
] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
|
|
|
|
"""Bind functions (and other objects) to this chat model.
|
|
|
|
@ -658,18 +665,26 @@ class ChatOpenAI(BaseChatModel):
|
|
|
|
|
|
|
|
|
|
formatted_functions = [convert_to_openai_function(fn) for fn in functions]
|
|
|
|
|
if function_call is not None:
|
|
|
|
|
if len(formatted_functions) != 1:
|
|
|
|
|
function_call = (
|
|
|
|
|
{"name": function_call}
|
|
|
|
|
if isinstance(function_call, str)
|
|
|
|
|
and function_call not in ("auto", "none")
|
|
|
|
|
else function_call
|
|
|
|
|
)
|
|
|
|
|
if isinstance(function_call, dict) and len(formatted_functions) != 1:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"When specifying `function_call`, you must provide exactly one "
|
|
|
|
|
"function."
|
|
|
|
|
)
|
|
|
|
|
if formatted_functions[0]["name"] != function_call:
|
|
|
|
|
if (
|
|
|
|
|
isinstance(function_call, dict)
|
|
|
|
|
and formatted_functions[0]["name"] != function_call["name"]
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Function call {function_call} was specified, but the only "
|
|
|
|
|
f"provided function was {formatted_functions[0]['name']}."
|
|
|
|
|
)
|
|
|
|
|
function_call_ = {"name": function_call}
|
|
|
|
|
kwargs = {**kwargs, "function_call": function_call_}
|
|
|
|
|
kwargs = {**kwargs, "function_call": function_call}
|
|
|
|
|
return super().bind(
|
|
|
|
|
functions=formatted_functions,
|
|
|
|
|
**kwargs,
|
|
|
|
|