From c54d6eb5da4e4dd618e642c23953ca9cd5629127 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 1 Mar 2024 11:12:28 -0800 Subject: [PATCH] fireworks[patch]: support "any" tool_choice (#18343) per https://readme.fireworks.ai/docs/function-calling --- .../fireworks/langchain_fireworks/chat_models.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index a29a718771..cc5a862c4f 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -579,7 +579,9 @@ class ChatFireworks(BaseChatModel): self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], *, - tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "any", "none"], bool] + ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -592,9 +594,10 @@ class ChatFireworks(BaseChatModel): models, callables, and BaseTools will be automatically converted to their schema dictionary representation. tool_choice: Which tool to require the model to call. - Must be the name of the single provided function or + Must be the name of the single provided function, "auto" to automatically determine which function to call - (if any), or a dict of the form: + with the option to not call any function, "any" to enforce that some + function is called, or a dict of the form: {"type": "function", "function": {"name": <>}}. **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. @@ -602,7 +605,9 @@ class ChatFireworks(BaseChatModel): formatted_tools = [convert_to_openai_tool(tool) for tool in tools] if tool_choice is not None and tool_choice: - if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")): + if isinstance(tool_choice, str) and ( + tool_choice not in ("auto", "any", "none") + ): tool_choice = {"type": "function", "function": {"name": tool_choice}} if isinstance(tool_choice, dict) and (len(formatted_tools) != 1): raise ValueError(