From f60f59d69f25f746f3494bb8f7d16168ae4b079a Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 18 Jan 2024 12:17:40 -0800 Subject: [PATCH] google-vertexai[patch]: Harrison/vertex function calling (#16223) Co-authored-by: Erick Friis --- libs/partners/google-vertexai/Makefile | 4 +- .../langchain_google_vertexai/__init__.py | 4 + .../langchain_google_vertexai/chains.py | 111 ++++++++++++++++++ .../functions_utils.py | 104 +++++++++++++++- .../tests/unit_tests/test_imports.py | 2 + 5 files changed, 220 insertions(+), 5 deletions(-) create mode 100644 libs/partners/google-vertexai/langchain_google_vertexai/chains.py diff --git a/libs/partners/google-vertexai/Makefile b/libs/partners/google-vertexai/Makefile index a1a4607ae6..29214d4bbc 100644 --- a/libs/partners/google-vertexai/Makefile +++ b/libs/partners/google-vertexai/Makefile @@ -6,9 +6,9 @@ all: help # Define a variable for the test file path. TEST_FILE ?= tests/unit_tests/ -test_integration: TEST_FILE = tests/integration_tests/ +integration_tests: TEST_FILE = tests/integration_tests/ -test test_integration: +test integration_tests: poetry run pytest $(TEST_FILE) tests: diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/__init__.py b/libs/partners/google-vertexai/langchain_google_vertexai/__init__.py index ba97adf52e..be365bde4c 100644 --- a/libs/partners/google-vertexai/langchain_google_vertexai/__init__.py +++ b/libs/partners/google-vertexai/langchain_google_vertexai/__init__.py @@ -1,6 +1,8 @@ from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory +from langchain_google_vertexai.chains import create_structured_runnable from langchain_google_vertexai.chat_models import ChatVertexAI from langchain_google_vertexai.embeddings import VertexAIEmbeddings +from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden __all__ = [ @@ -10,4 +12,6 @@ __all__ = [ "VertexAIModelGarden", "HarmBlockThreshold", "HarmCategory", + "PydanticFunctionsOutputParser", + "create_structured_runnable", ] diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/chains.py b/libs/partners/google-vertexai/langchain_google_vertexai/chains.py new file mode 100644 index 0000000000..9b11794b30 --- /dev/null +++ b/libs/partners/google-vertexai/langchain_google_vertexai/chains.py @@ -0,0 +1,111 @@ +from typing import ( + Dict, + Optional, + Sequence, + Type, + Union, +) + +from langchain_core.output_parsers import ( + BaseGenerationOutputParser, + BaseOutputParser, +) +from langchain_core.prompts import BasePromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import Runnable + +from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser + + +def get_output_parser( + functions: Sequence[Type[BaseModel]], +) -> Union[BaseOutputParser, BaseGenerationOutputParser]: + """Get the appropriate function output parser given the user functions. + + Args: + functions: Sequence where element is a dictionary, a pydantic.BaseModel class, + or a Python function. If a dictionary is passed in, it is assumed to + already be a valid OpenAI function. + + Returns: + A PydanticFunctionsOutputParser + """ + function_names = [f.__name__ for f in functions] + if len(functions) > 1: + pydantic_schema: Union[Dict, Type[BaseModel]] = { + name: fn for name, fn in zip(function_names, functions) + } + else: + pydantic_schema = functions[0] + output_parser: Union[ + BaseOutputParser, BaseGenerationOutputParser + ] = PydanticFunctionsOutputParser(pydantic_schema=pydantic_schema) + return output_parser + + +def create_structured_runnable( + function: Union[Type[BaseModel], Sequence[Type[BaseModel]]], + llm: Runnable, + *, + prompt: Optional[BasePromptTemplate] = None, +) -> Runnable: + """Create a runnable sequence that uses OpenAI functions. + + Args: + function: Either a single pydantic.BaseModel class or a sequence + of pydantic.BaseModels classes. + For best results, pydantic.BaseModels + should have descriptions of the parameters. + llm: Language model to use, + assumed to support the Google Vertex function-calling API. + prompt: BasePromptTemplate to pass to the model. + + Returns: + A runnable sequence that will pass in the given functions to the model when run. + + Example: + .. code-block:: python + + from typing import Optional + + from langchain_google_vertexai import ChatVertexAI, create_structured_runnable + from langchain_core.prompts import ChatPromptTemplate + from langchain_core.pydantic_v1 import BaseModel, Field + + + class RecordPerson(BaseModel): + \"\"\"Record some identifying information about a person.\"\"\" + + name: str = Field(..., description="The person's name") + age: int = Field(..., description="The person's age") + fav_food: Optional[str] = Field(None, description="The person's favorite food") + + + class RecordDog(BaseModel): + \"\"\"Record some identifying information about a dog.\"\"\" + + name: str = Field(..., description="The dog's name") + color: str = Field(..., description="The dog's color") + fav_food: Optional[str] = Field(None, description="The dog's favorite food") + + + llm = ChatVertexAI(model_name="gemini-pro") + prompt = ChatPromptTemplate.from_template(\"\"\" + You are a world class algorithm for recording entities. + Make calls to the relevant function to record the entities in the following input: {input} + Tip: Make sure to answer in the correct format\"\"\" + ) + chain = create_structured_runnable([RecordPerson, RecordDog], llm, prompt=prompt) + chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"}) + # -> RecordDog(name="Harry", color="brown", fav_food="chicken") + """ # noqa: E501 + if not function: + raise ValueError("Need to pass in at least one function. Received zero.") + functions = function if isinstance(function, Sequence) else [function] + output_parser = get_output_parser(functions) + llm_with_functions = llm.bind(functions=functions) + if prompt is None: + initial_chain = llm_with_functions + else: + initial_chain = prompt | llm_with_functions + return initial_chain | output_parser diff --git a/libs/partners/google-vertexai/langchain_google_vertexai/functions_utils.py b/libs/partners/google-vertexai/langchain_google_vertexai/functions_utils.py index 8e6aed3da1..304e6a85c1 100644 --- a/libs/partners/google-vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/partners/google-vertexai/langchain_google_vertexai/functions_utils.py @@ -1,5 +1,10 @@ -from typing import List +import json +from typing import Dict, List, Type, Union +from langchain_core.exceptions import OutputParserException +from langchain_core.output_parsers import BaseOutputParser +from langchain_core.outputs import ChatGeneration, Generation +from langchain_core.pydantic_v1 import BaseModel from langchain_core.tools import Tool from langchain_core.utils.function_calling import FunctionDescription from langchain_core.utils.json_schema import dereference_refs @@ -11,6 +16,29 @@ from vertexai.preview.generative_models import ( ) +def _format_pydantic_to_vertex_function( + pydantic_model: Type[BaseModel], +) -> FunctionDescription: + schema = dereference_refs(pydantic_model.schema()) + schema.pop("definitions", None) + + return { + "name": schema["title"], + "description": schema["description"], + "parameters": { + "properties": { + k: { + "type": v["type"], + "description": v.get("description"), + } + for k, v in schema["properties"].items() + }, + "required": schema["required"], + "type": schema["type"], + }, + } + + def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription: "Format tool into the Vertex function API." if tool.args_schema: @@ -46,11 +74,81 @@ def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription: } -def _format_tools_to_vertex_tool(tools: List[Tool]) -> List[VertexTool]: +def _format_tools_to_vertex_tool( + tools: List[Union[Tool, Type[BaseModel]]], +) -> List[VertexTool]: "Format tool into the Vertex Tool instance." function_declarations = [] for tool in tools: - func = _format_tool_to_vertex_function(tool) + if isinstance(tool, Tool): + func = _format_tool_to_vertex_function(tool) + else: + func = _format_pydantic_to_vertex_function(tool) function_declarations.append(FunctionDeclaration(**func)) return [VertexTool(function_declarations=function_declarations)] + + +class PydanticFunctionsOutputParser(BaseOutputParser): + """Parse an output as a pydantic object. + + This parser is used to parse the output of a ChatModel that uses + Google Vertex function format to invoke functions. + + The parser extracts the function call invocation and matches + them to the pydantic schema provided. + + An exception will be raised if the function call does not match + the provided schema. + + Example: + + ... code-block:: python + + message = AIMessage( + content="This is a test message", + additional_kwargs={ + "function_call": { + "name": "cookie", + "arguments": json.dumps({"name": "value", "age": 10}), + } + }, + ) + chat_generation = ChatGeneration(message=message) + + class Cookie(BaseModel): + name: str + age: int + + class Dog(BaseModel): + species: str + + # Full output + parser = PydanticOutputFunctionsParser( + pydantic_schema={"cookie": Cookie, "dog": Dog} + ) + result = parser.parse_result([chat_generation]) + """ + + pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]] + + def parse_result( + self, result: List[Generation], *, partial: bool = False + ) -> BaseModel: + if not isinstance(result[0], ChatGeneration): + raise ValueError("This output parser only works on ChatGeneration output") + message = result[0].message + function_call = message.additional_kwargs.get("function_call", {}) + if function_call: + function_name = function_call["name"] + tool_input = function_call.get("arguments", {}) + if isinstance(self.pydantic_schema, dict): + schema = self.pydantic_schema[function_name] + else: + schema = self.pydantic_schema + return schema(**json.loads(tool_input)) + else: + raise OutputParserException(f"Could not parse function call: {message}") + + def parse(self, text: str) -> BaseModel: + raise ValueError("Can only parse messages") diff --git a/libs/partners/google-vertexai/tests/unit_tests/test_imports.py b/libs/partners/google-vertexai/tests/unit_tests/test_imports.py index 11e91afcbe..7afa74f1dc 100644 --- a/libs/partners/google-vertexai/tests/unit_tests/test_imports.py +++ b/libs/partners/google-vertexai/tests/unit_tests/test_imports.py @@ -7,6 +7,8 @@ EXPECTED_ALL = [ "VertexAIModelGarden", "HarmBlockThreshold", "HarmCategory", + "PydanticFunctionsOutputParser", + "create_structured_runnable", ]