google-vertexai[patch]: Harrison/vertex function calling (#16223)

Co-authored-by: Erick Friis <erick@langchain.dev>
Harrison Chase 2024-01-18 12:17:40 -08:00 committed by GitHub
parent 6bc6d64a12
commit f60f59d69f
5 changed files with 220 additions and 5 deletions

Makefile

@@ -6,9 +6,9 @@ all: help
 # Define a variable for the test file path.
 TEST_FILE ?= tests/unit_tests/

-test_integration: TEST_FILE = tests/integration_tests/
+integration_tests: TEST_FILE = tests/integration_tests/

-test test_integration:
+test integration_tests:
 	poetry run pytest $(TEST_FILE)

 tests:

langchain_google_vertexai/__init__.py

@@ -1,6 +1,8 @@
 from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
+from langchain_google_vertexai.chains import create_structured_runnable
 from langchain_google_vertexai.chat_models import ChatVertexAI
 from langchain_google_vertexai.embeddings import VertexAIEmbeddings
+from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
 from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden

 __all__ = [
@@ -10,4 +12,6 @@ __all__ = [
     "VertexAIModelGarden",
     "HarmBlockThreshold",
     "HarmCategory",
+    "PydanticFunctionsOutputParser",
+    "create_structured_runnable",
 ]

langchain_google_vertexai/chains.py (new file)

@@ -0,0 +1,111 @@
from typing import (
    Dict,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain_core.output_parsers import (
    BaseGenerationOutputParser,
    BaseOutputParser,
)
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable

from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser


def get_output_parser(
    functions: Sequence[Type[BaseModel]],
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
    """Get the appropriate function output parser given the user functions.

    Args:
        functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
            or a Python function. If a dictionary is passed in, it is assumed to
            already be a valid OpenAI function.

    Returns:
        A PydanticFunctionsOutputParser
    """
    function_names = [f.__name__ for f in functions]
    if len(functions) > 1:
        pydantic_schema: Union[Dict, Type[BaseModel]] = {
            name: fn for name, fn in zip(function_names, functions)
        }
    else:
        pydantic_schema = functions[0]
    output_parser: Union[
        BaseOutputParser, BaseGenerationOutputParser
    ] = PydanticFunctionsOutputParser(pydantic_schema=pydantic_schema)
    return output_parser


def create_structured_runnable(
    function: Union[Type[BaseModel], Sequence[Type[BaseModel]]],
    llm: Runnable,
    *,
    prompt: Optional[BasePromptTemplate] = None,
) -> Runnable:
    """Create a runnable sequence that uses the Google Vertex function-calling API.

    Args:
        function: Either a single pydantic.BaseModel class or a sequence
            of pydantic.BaseModel classes.
            For best results, pydantic.BaseModels
            should have descriptions of the parameters.
        llm: Language model to use,
            assumed to support the Google Vertex function-calling API.
        prompt: BasePromptTemplate to pass to the model.

    Returns:
        A runnable sequence that will pass in the given functions to the model when run.

    Example:
        .. code-block:: python

            from typing import Optional

            from langchain_google_vertexai import ChatVertexAI, create_structured_runnable
            from langchain_core.prompts import ChatPromptTemplate
            from langchain_core.pydantic_v1 import BaseModel, Field


            class RecordPerson(BaseModel):
                \"\"\"Record some identifying information about a person.\"\"\"

                name: str = Field(..., description="The person's name")
                age: int = Field(..., description="The person's age")
                fav_food: Optional[str] = Field(None, description="The person's favorite food")


            class RecordDog(BaseModel):
                \"\"\"Record some identifying information about a dog.\"\"\"

                name: str = Field(..., description="The dog's name")
                color: str = Field(..., description="The dog's color")
                fav_food: Optional[str] = Field(None, description="The dog's favorite food")


            llm = ChatVertexAI(model_name="gemini-pro")
            prompt = ChatPromptTemplate.from_template(\"\"\"
    You are a world class algorithm for recording entities.
    Make calls to the relevant function to record the entities in the following input: {input}
    Tip: Make sure to answer in the correct format\"\"\"
            )
            chain = create_structured_runnable([RecordPerson, RecordDog], llm, prompt=prompt)
            chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
            # -> RecordDog(name="Harry", color="brown", fav_food="chicken")
    """  # noqa: E501
    if not function:
        raise ValueError("Need to pass in at least one function. Received zero.")
    functions = function if isinstance(function, Sequence) else [function]
    output_parser = get_output_parser(functions)
    llm_with_functions = llm.bind(functions=functions)
    if prompt is None:
        initial_chain = llm_with_functions
    else:
        initial_chain = prompt | llm_with_functions
    return initial_chain | output_parser
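As a rough illustration (not part of the commit) of how get_output_parser picks the schema it hands to PydanticFunctionsOutputParser; the Movie and Song models here are hypothetical:

from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_google_vertexai.chains import get_output_parser


class Movie(BaseModel):
    """Record the title and year of a movie."""

    title: str = Field(..., description="The movie title")
    year: int = Field(..., description="The release year")


class Song(BaseModel):
    """Record the title of a song."""

    title: str = Field(..., description="The song title")


# A single model is used directly as the parser's pydantic_schema; several
# models become a {function_name: model} mapping so the parser can dispatch
# on the name returned in the model's function call.
single_parser = get_output_parser([Movie])
multi_parser = get_output_parser([Movie, Song])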

langchain_google_vertexai/functions_utils.py

@@ -1,5 +1,10 @@
-from typing import List
+import json
+from typing import Dict, List, Type, Union

+from langchain_core.exceptions import OutputParserException
+from langchain_core.output_parsers import BaseOutputParser
+from langchain_core.outputs import ChatGeneration, Generation
+from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.tools import Tool
 from langchain_core.utils.function_calling import FunctionDescription
 from langchain_core.utils.json_schema import dereference_refs
@@ -11,6 +16,29 @@ from vertexai.preview.generative_models import (
 )


+def _format_pydantic_to_vertex_function(
+    pydantic_model: Type[BaseModel],
+) -> FunctionDescription:
+    schema = dereference_refs(pydantic_model.schema())
+    schema.pop("definitions", None)
+    return {
+        "name": schema["title"],
+        "description": schema["description"],
+        "parameters": {
+            "properties": {
+                k: {
+                    "type": v["type"],
+                    "description": v.get("description"),
+                }
+                for k, v in schema["properties"].items()
+            },
+            "required": schema["required"],
+            "type": schema["type"],
+        },
+    }
+
+
 def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
     "Format tool into the Vertex function API."
     if tool.args_schema:
@@ -46,11 +74,81 @@ def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
     }


-def _format_tools_to_vertex_tool(tools: List[Tool]) -> List[VertexTool]:
+def _format_tools_to_vertex_tool(
+    tools: List[Union[Tool, Type[BaseModel]]],
+) -> List[VertexTool]:
     "Format tool into the Vertex Tool instance."
     function_declarations = []
     for tool in tools:
-        func = _format_tool_to_vertex_function(tool)
+        if isinstance(tool, Tool):
+            func = _format_tool_to_vertex_function(tool)
+        else:
+            func = _format_pydantic_to_vertex_function(tool)
         function_declarations.append(FunctionDeclaration(**func))
     return [VertexTool(function_declarations=function_declarations)]
+
+
+class PydanticFunctionsOutputParser(BaseOutputParser):
+    """Parse an output as a pydantic object.
+
+    This parser is used to parse the output of a ChatModel that uses
+    Google Vertex function format to invoke functions.
+
+    The parser extracts the function call invocation and matches
+    it to the pydantic schema provided.
+
+    An exception will be raised if the function call does not match
+    the provided schema.
+
+    Example:
+
+        .. code-block:: python
+
+            message = AIMessage(
+                content="This is a test message",
+                additional_kwargs={
+                    "function_call": {
+                        "name": "cookie",
+                        "arguments": json.dumps({"name": "value", "age": 10}),
+                    }
+                },
+            )
+            chat_generation = ChatGeneration(message=message)
+
+            class Cookie(BaseModel):
+                name: str
+                age: int
+
+            class Dog(BaseModel):
+                species: str
+
+            # Full output
+            parser = PydanticFunctionsOutputParser(
+                pydantic_schema={"cookie": Cookie, "dog": Dog}
+            )
+            result = parser.parse_result([chat_generation])
+    """
+
+    pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
+
+    def parse_result(
+        self, result: List[Generation], *, partial: bool = False
+    ) -> BaseModel:
+        if not isinstance(result[0], ChatGeneration):
+            raise ValueError("This output parser only works on ChatGeneration output")
+        message = result[0].message
+        function_call = message.additional_kwargs.get("function_call", {})
+        if function_call:
+            function_name = function_call["name"]
+            tool_input = function_call.get("arguments", {})
+            if isinstance(self.pydantic_schema, dict):
+                schema = self.pydantic_schema[function_name]
+            else:
+                schema = self.pydantic_schema
+            return schema(**json.loads(tool_input))
+        else:
+            raise OutputParserException(f"Could not parse function call: {message}")
+
+    def parse(self, text: str) -> BaseModel:
+        raise ValueError("Can only parse messages")
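For orientation, a sketch (not part of the commit) of the FunctionDescription that _format_pydantic_to_vertex_function builds from a pydantic v1 model; the Recipe model is hypothetical:

from langchain_core.pydantic_v1 import BaseModel, Field


class Recipe(BaseModel):
    """Record a recipe."""

    name: str = Field(..., description="The recipe name")
    minutes: int = Field(..., description="Total cooking time in minutes")


# _format_pydantic_to_vertex_function(Recipe) should produce roughly:
# {
#     "name": "Recipe",
#     "description": "Record a recipe.",
#     "parameters": {
#         "properties": {
#             "name": {"type": "string", "description": "The recipe name"},
#             "minutes": {"type": "integer", "description": "Total cooking time in minutes"},
#         },
#         "required": ["name", "minutes"],
#         "type": "object",
#     },
# }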

tests/unit_tests/test_imports.py

@@ -7,6 +7,8 @@ EXPECTED_ALL = [
     "VertexAIModelGarden",
     "HarmBlockThreshold",
     "HarmCategory",
+    "PydanticFunctionsOutputParser",
+    "create_structured_runnable",
 ]