diff --git a/libs/partners/fireworks/langchain_fireworks/chat_models.py b/libs/partners/fireworks/langchain_fireworks/chat_models.py index a29a718771..cc5a862c4f 100644 --- a/libs/partners/fireworks/langchain_fireworks/chat_models.py +++ b/libs/partners/fireworks/langchain_fireworks/chat_models.py @@ -579,7 +579,9 @@ class ChatFireworks(BaseChatModel): self, tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], *, - tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "any", "none"], bool] + ] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tool-like objects to this chat model. @@ -592,9 +594,10 @@ class ChatFireworks(BaseChatModel): models, callables, and BaseTools will be automatically converted to their schema dictionary representation. tool_choice: Which tool to require the model to call. - Must be the name of the single provided function or + Must be the name of the single provided function, "auto" to automatically determine which function to call - (if any), or a dict of the form: + with the option to not call any function, "any" to enforce that some + function is called, or a dict of the form: {"type": "function", "function": {"name": <>}}. **kwargs: Any additional parameters to pass to the :class:`~langchain.runnable.Runnable` constructor. @@ -602,7 +605,9 @@ class ChatFireworks(BaseChatModel): formatted_tools = [convert_to_openai_tool(tool) for tool in tools] if tool_choice is not None and tool_choice: - if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")): + if isinstance(tool_choice, str) and ( + tool_choice not in ("auto", "any", "none") + ): tool_choice = {"type": "function", "function": {"name": tool_choice}} if isinstance(tool_choice, dict) and (len(formatted_tools) != 1): raise ValueError(