mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
google-vertexai[patch]: Harrison/vertex function calling (#16223)
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
6bc6d64a12
commit
f60f59d69f
@ -6,9 +6,9 @@ all: help
|
|||||||
# Define a variable for the test file path.
|
# Define a variable for the test file path.
|
||||||
TEST_FILE ?= tests/unit_tests/
|
TEST_FILE ?= tests/unit_tests/
|
||||||
|
|
||||||
test_integration: TEST_FILE = tests/integration_tests/
|
integration_tests: TEST_FILE = tests/integration_tests/
|
||||||
|
|
||||||
test test_integration:
|
test integration_tests:
|
||||||
poetry run pytest $(TEST_FILE)
|
poetry run pytest $(TEST_FILE)
|
||||||
|
|
||||||
tests:
|
tests:
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
|
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
|
||||||
|
from langchain_google_vertexai.chains import create_structured_runnable
|
||||||
from langchain_google_vertexai.chat_models import ChatVertexAI
|
from langchain_google_vertexai.chat_models import ChatVertexAI
|
||||||
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
|
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
|
||||||
|
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
|
||||||
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
|
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -10,4 +12,6 @@ __all__ = [
|
|||||||
"VertexAIModelGarden",
|
"VertexAIModelGarden",
|
||||||
"HarmBlockThreshold",
|
"HarmBlockThreshold",
|
||||||
"HarmCategory",
|
"HarmCategory",
|
||||||
|
"PydanticFunctionsOutputParser",
|
||||||
|
"create_structured_runnable",
|
||||||
]
|
]
|
||||||
|
@ -0,0 +1,111 @@
|
|||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain_core.output_parsers import (
|
||||||
|
BaseGenerationOutputParser,
|
||||||
|
BaseOutputParser,
|
||||||
|
)
|
||||||
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
from langchain_core.runnables import Runnable
|
||||||
|
|
||||||
|
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
|
||||||
|
|
||||||
|
|
||||||
|
def get_output_parser(
|
||||||
|
functions: Sequence[Type[BaseModel]],
|
||||||
|
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
|
||||||
|
"""Get the appropriate function output parser given the user functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
|
||||||
|
or a Python function. If a dictionary is passed in, it is assumed to
|
||||||
|
already be a valid OpenAI function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A PydanticFunctionsOutputParser
|
||||||
|
"""
|
||||||
|
function_names = [f.__name__ for f in functions]
|
||||||
|
if len(functions) > 1:
|
||||||
|
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
||||||
|
name: fn for name, fn in zip(function_names, functions)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
pydantic_schema = functions[0]
|
||||||
|
output_parser: Union[
|
||||||
|
BaseOutputParser, BaseGenerationOutputParser
|
||||||
|
] = PydanticFunctionsOutputParser(pydantic_schema=pydantic_schema)
|
||||||
|
return output_parser
|
||||||
|
|
||||||
|
|
||||||
|
def create_structured_runnable(
|
||||||
|
function: Union[Type[BaseModel], Sequence[Type[BaseModel]]],
|
||||||
|
llm: Runnable,
|
||||||
|
*,
|
||||||
|
prompt: Optional[BasePromptTemplate] = None,
|
||||||
|
) -> Runnable:
|
||||||
|
"""Create a runnable sequence that uses OpenAI functions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
function: Either a single pydantic.BaseModel class or a sequence
|
||||||
|
of pydantic.BaseModels classes.
|
||||||
|
For best results, pydantic.BaseModels
|
||||||
|
should have descriptions of the parameters.
|
||||||
|
llm: Language model to use,
|
||||||
|
assumed to support the Google Vertex function-calling API.
|
||||||
|
prompt: BasePromptTemplate to pass to the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A runnable sequence that will pass in the given functions to the model when run.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from langchain_google_vertexai import ChatVertexAI, create_structured_runnable
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class RecordPerson(BaseModel):
|
||||||
|
\"\"\"Record some identifying information about a person.\"\"\"
|
||||||
|
|
||||||
|
name: str = Field(..., description="The person's name")
|
||||||
|
age: int = Field(..., description="The person's age")
|
||||||
|
fav_food: Optional[str] = Field(None, description="The person's favorite food")
|
||||||
|
|
||||||
|
|
||||||
|
class RecordDog(BaseModel):
|
||||||
|
\"\"\"Record some identifying information about a dog.\"\"\"
|
||||||
|
|
||||||
|
name: str = Field(..., description="The dog's name")
|
||||||
|
color: str = Field(..., description="The dog's color")
|
||||||
|
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
|
||||||
|
|
||||||
|
|
||||||
|
llm = ChatVertexAI(model_name="gemini-pro")
|
||||||
|
prompt = ChatPromptTemplate.from_template(\"\"\"
|
||||||
|
You are a world class algorithm for recording entities.
|
||||||
|
Make calls to the relevant function to record the entities in the following input: {input}
|
||||||
|
Tip: Make sure to answer in the correct format\"\"\"
|
||||||
|
)
|
||||||
|
chain = create_structured_runnable([RecordPerson, RecordDog], llm, prompt=prompt)
|
||||||
|
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||||
|
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
||||||
|
""" # noqa: E501
|
||||||
|
if not function:
|
||||||
|
raise ValueError("Need to pass in at least one function. Received zero.")
|
||||||
|
functions = function if isinstance(function, Sequence) else [function]
|
||||||
|
output_parser = get_output_parser(functions)
|
||||||
|
llm_with_functions = llm.bind(functions=functions)
|
||||||
|
if prompt is None:
|
||||||
|
initial_chain = llm_with_functions
|
||||||
|
else:
|
||||||
|
initial_chain = prompt | llm_with_functions
|
||||||
|
return initial_chain | output_parser
|
@ -1,5 +1,10 @@
|
|||||||
from typing import List
|
import json
|
||||||
|
from typing import Dict, List, Type, Union
|
||||||
|
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.tools import Tool
|
from langchain_core.tools import Tool
|
||||||
from langchain_core.utils.function_calling import FunctionDescription
|
from langchain_core.utils.function_calling import FunctionDescription
|
||||||
from langchain_core.utils.json_schema import dereference_refs
|
from langchain_core.utils.json_schema import dereference_refs
|
||||||
@ -11,6 +16,29 @@ from vertexai.preview.generative_models import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_pydantic_to_vertex_function(
|
||||||
|
pydantic_model: Type[BaseModel],
|
||||||
|
) -> FunctionDescription:
|
||||||
|
schema = dereference_refs(pydantic_model.schema())
|
||||||
|
schema.pop("definitions", None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": schema["title"],
|
||||||
|
"description": schema["description"],
|
||||||
|
"parameters": {
|
||||||
|
"properties": {
|
||||||
|
k: {
|
||||||
|
"type": v["type"],
|
||||||
|
"description": v.get("description"),
|
||||||
|
}
|
||||||
|
for k, v in schema["properties"].items()
|
||||||
|
},
|
||||||
|
"required": schema["required"],
|
||||||
|
"type": schema["type"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
|
def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
|
||||||
"Format tool into the Vertex function API."
|
"Format tool into the Vertex function API."
|
||||||
if tool.args_schema:
|
if tool.args_schema:
|
||||||
@ -46,11 +74,81 @@ def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _format_tools_to_vertex_tool(tools: List[Tool]) -> List[VertexTool]:
|
def _format_tools_to_vertex_tool(
|
||||||
|
tools: List[Union[Tool, Type[BaseModel]]],
|
||||||
|
) -> List[VertexTool]:
|
||||||
"Format tool into the Vertex Tool instance."
|
"Format tool into the Vertex Tool instance."
|
||||||
function_declarations = []
|
function_declarations = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
func = _format_tool_to_vertex_function(tool)
|
if isinstance(tool, Tool):
|
||||||
|
func = _format_tool_to_vertex_function(tool)
|
||||||
|
else:
|
||||||
|
func = _format_pydantic_to_vertex_function(tool)
|
||||||
function_declarations.append(FunctionDeclaration(**func))
|
function_declarations.append(FunctionDeclaration(**func))
|
||||||
|
|
||||||
return [VertexTool(function_declarations=function_declarations)]
|
return [VertexTool(function_declarations=function_declarations)]
|
||||||
|
|
||||||
|
|
||||||
|
class PydanticFunctionsOutputParser(BaseOutputParser):
|
||||||
|
"""Parse an output as a pydantic object.
|
||||||
|
|
||||||
|
This parser is used to parse the output of a ChatModel that uses
|
||||||
|
Google Vertex function format to invoke functions.
|
||||||
|
|
||||||
|
The parser extracts the function call invocation and matches
|
||||||
|
them to the pydantic schema provided.
|
||||||
|
|
||||||
|
An exception will be raised if the function call does not match
|
||||||
|
the provided schema.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
... code-block:: python
|
||||||
|
|
||||||
|
message = AIMessage(
|
||||||
|
content="This is a test message",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": {
|
||||||
|
"name": "cookie",
|
||||||
|
"arguments": json.dumps({"name": "value", "age": 10}),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chat_generation = ChatGeneration(message=message)
|
||||||
|
|
||||||
|
class Cookie(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
class Dog(BaseModel):
|
||||||
|
species: str
|
||||||
|
|
||||||
|
# Full output
|
||||||
|
parser = PydanticOutputFunctionsParser(
|
||||||
|
pydantic_schema={"cookie": Cookie, "dog": Dog}
|
||||||
|
)
|
||||||
|
result = parser.parse_result([chat_generation])
|
||||||
|
"""
|
||||||
|
|
||||||
|
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
|
||||||
|
|
||||||
|
def parse_result(
|
||||||
|
self, result: List[Generation], *, partial: bool = False
|
||||||
|
) -> BaseModel:
|
||||||
|
if not isinstance(result[0], ChatGeneration):
|
||||||
|
raise ValueError("This output parser only works on ChatGeneration output")
|
||||||
|
message = result[0].message
|
||||||
|
function_call = message.additional_kwargs.get("function_call", {})
|
||||||
|
if function_call:
|
||||||
|
function_name = function_call["name"]
|
||||||
|
tool_input = function_call.get("arguments", {})
|
||||||
|
if isinstance(self.pydantic_schema, dict):
|
||||||
|
schema = self.pydantic_schema[function_name]
|
||||||
|
else:
|
||||||
|
schema = self.pydantic_schema
|
||||||
|
return schema(**json.loads(tool_input))
|
||||||
|
else:
|
||||||
|
raise OutputParserException(f"Could not parse function call: {message}")
|
||||||
|
|
||||||
|
def parse(self, text: str) -> BaseModel:
|
||||||
|
raise ValueError("Can only parse messages")
|
||||||
|
@ -7,6 +7,8 @@ EXPECTED_ALL = [
|
|||||||
"VertexAIModelGarden",
|
"VertexAIModelGarden",
|
||||||
"HarmBlockThreshold",
|
"HarmBlockThreshold",
|
||||||
"HarmCategory",
|
"HarmCategory",
|
||||||
|
"PydanticFunctionsOutputParser",
|
||||||
|
"create_structured_runnable",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user