From 9927a4866d9687dea893765ce07e1195295ce4e5 Mon Sep 17 00:00:00 2001 From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com> Date: Mon, 12 Aug 2024 19:11:43 +0500 Subject: [PATCH] [Community] - Added bind_tools and with_structured_output for ChatZhipuAI (#23887) - **Description:** This PR implements the `bind_tool` functionality for ChatZhipuAI as requested by the user. ChatZhipuAI models support tool calling according to the `OpenAI` tool format, as outlined in their official documentation [here](https://open.bigmodel.cn/dev/api#glm-4). - **Issue:** ##23868 --------- Co-authored-by: ccurme --- .../chat_models/zhipuai.py | 202 +++++++++++++++++- 1 file changed, 201 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index b551b21ff8..03d8583987 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -7,12 +7,25 @@ import logging import time from collections.abc import AsyncIterator, Iterator from contextlib import asynccontextmanager, contextmanager -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from operator import itemgetter +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + Union, +) from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, agenerate_from_stream, @@ -30,9 +43,17 @@ from langchain_core.messages import ( SystemMessage, SystemMessageChunk, ) +from langchain_core.output_parsers.base import OutputParserLike +from langchain_core.output_parsers.openai_tools import ( + JsonOutputKeyToolsParser, + PydanticToolsParser, +) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.tools import BaseTool from langchain_core.utils import get_from_dict_or_env +from langchain_core.utils.function_calling import convert_to_openai_tool logger = logging.getLogger(__name__) @@ -40,6 +61,10 @@ API_TOKEN_TTL_SECONDS = 3 * 60 ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions" +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and issubclass(obj, BaseModel) + + @contextmanager def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator: """Context manager for connecting to an SSE stream. @@ -587,3 +612,178 @@ class ChatZhipuAI(BaseChatModel): if finish_reason is not None: break + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "any", "none"], bool] + ] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + Args: + tools: A list of tool definitions to bind to this chat model. + Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic + models, callables, and BaseTools will be automatically converted to + their schema dictionary representation. + tool_choice: Currently this can only be auto for this chat model. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + if self.model_name == "glm-4v": + raise ValueError("glm-4v currently does not support tool calling") + + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + if tool_choice and tool_choice != "auto": + raise ValueError("ChatZhipuAI currently only supports `auto` tool choice") + elif tool_choice and tool_choice == "auto": + kwargs["tool_choice"] = tool_choice + return self.bind(tools=formatted_tools, **kwargs) + + def with_structured_output( + self, + schema: Optional[Union[Dict, Type[BaseModel]]] = None, + *, + method: Literal["function_calling", "json_mode"] = "function_calling", + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema as a dict or a Pydantic class. If a Pydantic class + then the model output will be an object of that class. If a dict then + the model output will be a dict. With a Pydantic class the returned + attributes will be validated, whereas with a dict they will not be. If + `method` is "function_calling" and `schema` is a dict, then the dict + must match the OpenAI function-calling spec. + method: The method for steering model generation, either "function_calling" + or "json_mode". ZhipuAI only supports "function_calling" which + converts the schema to a OpenAI function and the model will make use of the + function-calling API. + include_raw: If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes any ChatModel input and returns as output: + + If include_raw is True then a dict with keys: + raw: BaseMessage + parsed: Optional[_DictOrPydantic] + parsing_error: Optional[BaseException] + + If include_raw is False then just _DictOrPydantic is returned, + where _DictOrPydantic depends on the schema: + + If schema is a Pydantic class then _DictOrPydantic is the Pydantic + class. + + If schema is a dict then _DictOrPydantic is a dict. + + Example: Function-calling, Pydantic schema (method="function_calling", include_raw=False): + .. code-block:: python + + from langchain_community.chat_models import ChatZhipuAI + from langchain_core.pydantic_v1 import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm = ChatZhipuAI(temperature=0) + structured_llm = llm.with_structured_output(AnswerWithJustification) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> AnswerWithJustification( + # answer='A pound of bricks and a pound of feathers weigh the same.' + # justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same." + # ) + + Example: Function-calling, Pydantic schema (method="function_calling", include_raw=True): + .. code-block:: python + + from langchain_community.chat_models import ChatZhipuAI + from langchain_core.pydantic_v1 import BaseModel + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + llm = ChatZhipuAI(temperature=0) + structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_01htjn3cspevxbqc1d7nkk8wab', 'function': {'arguments': '{"answer": "A pound of bricks and a pound of feathers weigh the same.", "justification": "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The \'pound\' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.", "unit": "pounds"}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}, id='run-456beee6-65f6-4e80-88af-a6065480822c-0'), + # 'parsed': AnswerWithJustification(answer='A pound of bricks and a pound of feathers weigh the same.', justification="Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same."), + # 'parsing_error': None + # } + + Example: Function-calling, dict schema (method="function_calling", include_raw=False): + .. code-block:: python + + from langchain_community.chat_models import ChatZhipuAI + from langchain_core.pydantic_v1 import BaseModel + from langchain_core.utils.function_calling import convert_to_openai_tool + + class AnswerWithJustification(BaseModel): + '''An answer to the user question along with justification for the answer.''' + answer: str + justification: str + + dict_schema = convert_to_openai_tool(AnswerWithJustification) + llm = ChatZhipuAI(temperature=0) + structured_llm = llm.with_structured_output(dict_schema) + + structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + # -> { + # 'answer': 'A pound of bricks and a pound of feathers weigh the same.', + # 'justification': "Both a pound of bricks and a pound of feathers have been defined to have the same weight. The 'pound' is a unit of weight, so any two things that are described as weighing a pound will weigh the same.", 'unit': 'pounds'} + # } + + """ # noqa: E501 + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + is_pydantic_schema = _is_pydantic_class(schema) + if method == "function_calling": + if schema is None: + raise ValueError( + "schema must be specified when method is 'function_calling'. " + "Received None." + ) + tool_name = convert_to_openai_tool(schema)["function"]["name"] + llm = self.bind_tools([schema], tool_choice="auto") + if is_pydantic_schema: + output_parser: OutputParserLike = PydanticToolsParser( + tools=[schema], # type: ignore[list-item] + first_tool_only=True, # type: ignore[list-item] + ) + else: + output_parser = JsonOutputKeyToolsParser( + key_name=tool_name, first_tool_only=True + ) + else: + raise ValueError( + f"""Unrecognized method argument. Expected 'function_calling'. + Received: '{method}'""" + ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser