google-genai[patch]: match function call interface (#17213)

Makes the function-call conversion interface match the Vertex AI integration.
pull/17215/head
Erick Friis 8 months ago committed by GitHub
parent e17173c403
commit 2ecf318218
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,135 @@
from __future__ import annotations
from typing import (
Dict,
List,
Type,
Union,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.json_schema import dereference_refs
# A single user-supplied function spec accepted by
# `convert_to_genai_function_declarations`: a LangChain tool, a Pydantic
# model class, or an already-built (OpenAI-style) dict.
FunctionCallType = Union[BaseTool, Type[BaseModel], Dict]

# JSON-schema type name -> google.ai.generativelanguage `Type` enum value
# (STRING = 1 ... OBJECT = 6).
TYPE_ENUM = {
    "string": 1,
    "number": 2,
    "integer": 3,
    "boolean": 4,
    "array": 5,
    "object": 6,
}
def convert_to_genai_function_declarations(
    function_calls: List[FunctionCallType],
) -> Dict:
    """Build the GenAI ``tools`` payload from a list of function specs.

    Each entry (tool, Pydantic model class, or dict) is converted to a
    function declaration and wrapped under ``"function_declarations"``.
    """
    return {
        "function_declarations": [
            _convert_to_genai_function(fc) for fc in function_calls
        ],
    }
def _convert_to_genai_function(fc: FunctionCallType) -> Dict:
    """Convert one function spec into a GenAI function-declaration dict.

    Dispatches on the spec kind (tool / Pydantic model / raw dict) and
    produces e.g.
    {
        "name": "get_weather",
        "description": "Determine weather in my location",
        "parameters": {
            "properties": {
                "location": {
                    "description": "The city and state e.g. San Francisco, CA",
                    "type_": 1
                },
                "unit": { "enum": ["c", "f"], "type_": 1 }
            },
            "required": ["location"],
            "type_": 6
        }
    }

    Raises:
        ValueError: if ``fc`` is none of the supported spec kinds.
    """
    if isinstance(fc, BaseTool):
        return _convert_tool_to_genai_function(fc)
    elif isinstance(fc, type) and issubclass(fc, BaseModel):
        return _convert_pydantic_to_genai_function(fc)
    elif isinstance(fc, dict):
        properties = {}
        for name, prop in fc["parameters"]["properties"].items():
            converted = {
                "type_": TYPE_ENUM[prop["type"]],
                "description": prop.get("description"),
            }
            # Preserve `enum` constraints (see the docstring example above);
            # previously they were silently dropped from dict-style specs.
            if "enum" in prop:
                converted["enum"] = prop["enum"]
            properties[name] = converted
        return {
            **fc,
            "parameters": {
                "properties": properties,
                # "required" is optional in JSON schema; default to none.
                "required": fc["parameters"].get("required", []),
                "type_": TYPE_ENUM[fc["parameters"]["type"]],
            },
        }
    else:
        raise ValueError(f"Unsupported function call type {fc}")
def _convert_tool_to_genai_function(tool: BaseTool) -> Dict:
    """Convert a LangChain ``BaseTool`` into a GenAI function declaration.

    Tools with an ``args_schema`` get a full parameter schema derived from
    the Pydantic model; tools without one fall back to a single required
    string argument named ``__arg1``.
    """
    if tool.args_schema:
        schema = dereference_refs(tool.args_schema.schema())
        schema.pop("definitions", None)  # refs already inlined above
        return {
            "name": tool.name or schema["title"],
            # Fall back to an empty description instead of a KeyError when
            # both the tool and its schema lack one.
            "description": tool.description or schema.get("description", ""),
            "parameters": {
                "properties": {
                    k: {
                        "type_": TYPE_ENUM[v["type"]],
                        "description": v.get("description"),
                    }
                    for k, v in schema["properties"].items()
                },
                # Pydantic omits "required" when every field has a default;
                # don't KeyError in that case.
                "required": schema.get("required", []),
                "type_": TYPE_ENUM[schema["type"]],
            },
        }
    else:
        return {
            "name": tool.name,
            "description": tool.description,
            "parameters": {
                "properties": {
                    # Use the proto-style "type_" enum like every other
                    # schema in this module (was a raw `"type": "string"`).
                    "__arg1": {"type_": TYPE_ENUM["string"]},
                },
                "required": ["__arg1"],
                "type_": TYPE_ENUM["object"],
            },
        }
def _convert_pydantic_to_genai_function(
    pydantic_model: Type[BaseModel],
) -> Dict:
    """Convert a Pydantic model class into a GenAI function declaration.

    The model's title becomes the function name, its docstring (if any)
    the description, and its fields the parameter schema.
    """
    schema = dereference_refs(pydantic_model.schema())
    schema.pop("definitions", None)  # refs already inlined above
    return {
        "name": schema["title"],
        "description": schema.get("description", ""),
        "parameters": {
            "properties": {
                k: {
                    "type_": TYPE_ENUM[v["type"]],
                    "description": v.get("description"),
                }
                for k, v in schema["properties"].items()
            },
            # Pydantic omits "required" when every field has a default;
            # don't KeyError in that case.
            "required": schema.get("required", []),
            "type_": TYPE_ENUM[schema["type"]],
        },
    }

@ -1,6 +1,7 @@
from __future__ import annotations
import base64
import json
import logging
import os
from io import BytesIO
@ -54,6 +55,9 @@ from tenacity import (
)
from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai._function_utils import (
convert_to_genai_function_declarations,
)
from langchain_google_genai.llms import GoogleModelFamily, _BaseGoogleGenerativeAI
IMAGE_TYPES: Tuple = ()
@ -351,69 +355,14 @@ def _retrieve_function_call_response(
return {
"function_call": {
"name": fc.name,
"arguments": dict(fc.args.items()),
"arguments": json.dumps(
dict(fc.args.items())
), # dump to match other function calling llms for now
}
}
return None
def _convert_function_call_req(function_calls: Union[Dict, List[Dict]]) -> Dict:
    """Normalize one-or-many function-call dicts into the GenAI tools dict."""
    if isinstance(function_calls, dict):
        function_calls = [function_calls]
    return {
        "function_declarations": [
            _convert_fc_type(fc) for fc in function_calls
        ],
    }
def _convert_fc_type(fc: Dict) -> Dict:
# type_: "Type"
# format_: str
# description: str
# nullable: bool
# enum: MutableSequence[str]
# items: "Schema"
# properties: MutableMapping[str, "Schema"]
# required: MutableSequence[str]
if "parameters" in fc:
fc["parameters"] = _convert_fc_type(fc["parameters"])
if "properties" in fc:
for k, v in fc["properties"].items():
fc["properties"][k] = _convert_fc_type(v)
if "type" in fc:
# STRING = 1
# NUMBER = 2
# INTEGER = 3
# BOOLEAN = 4
# ARRAY = 5
# OBJECT = 6
if fc["type"] == "string":
fc["type_"] = 1
elif fc["type"] == "number":
fc["type_"] = 2
elif fc["type"] == "integer":
fc["type_"] = 3
elif fc["type"] == "boolean":
fc["type_"] = 4
elif fc["type"] == "array":
fc["type_"] = 5
elif fc["type"] == "object":
fc["type_"] = 6
del fc["type"]
if "format" in fc:
fc["format_"] = fc["format"]
del fc["format"]
for k, v in fc.items():
if isinstance(v, dict):
fc[k] = _convert_fc_type(v)
return fc
def _parts_to_content(
parts: List[genai.types.PartType],
) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
@ -708,11 +657,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
cli = self.client
client = self.client
functions = kwargs.pop("functions", None)
if functions:
tools = _convert_function_call_req(functions)
cli = genai.GenerativeModel(model_name=self.model, tools=tools)
tools = convert_to_genai_function_declarations(functions)
client = genai.GenerativeModel(model_name=self.model, tools=tools)
params = self._prepare_params(stop, **kwargs)
history = _parse_chat_history(
@ -720,7 +669,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history.pop()
chat = cli.start_chat(history=history)
chat = client.start_chat(history=history)
return params, chat, message
def get_num_tokens(self, text: str) -> int:

@ -1,5 +1,11 @@
"""Test ChatGoogleGenerativeAI function call."""
import json
from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import tool
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
)
@ -29,6 +35,50 @@ def test_function_call() -> None:
assert res.additional_kwargs
assert "function_call" in res.additional_kwargs
assert "get_weather" == res.additional_kwargs["function_call"]["name"]
arguments = res.additional_kwargs["function_call"]["arguments"]
assert isinstance(arguments, dict)
arguments_str = res.additional_kwargs["function_call"]["arguments"]
assert isinstance(arguments_str, str)
arguments = json.loads(arguments_str)
assert "location" in arguments
def test_tool_call() -> None:
    """Integration test: binding a ``@tool`` should yield a function_call
    whose arguments are a JSON-encoded string."""

    @tool
    def search_tool(query: str) -> str:
        """Searches the web for `query` and returns the result."""
        raise NotImplementedError

    model = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[search_tool])
    result = model.invoke("weather in san francisco")

    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert result.content == ""

    fc = result.additional_kwargs.get("function_call")
    assert fc
    assert fc["name"] == "search_tool"

    raw_args = fc.get("arguments")
    assert raw_args
    parsed = json.loads(raw_args)
    assert "query" in parsed
# Minimal Pydantic schema used by test_pydantic_call to exercise passing a
# BaseModel subclass as a `functions` entry. (No class docstring on purpose:
# it would become the generated function declaration's description.)
class MyModel(BaseModel):
    name: str  # extracted from the prompt text
    age: int  # extracted from the prompt; the API may return it as a float
def test_pydantic_call() -> None:
    """Integration test: a Pydantic model bound via ``functions`` should
    yield a matching function_call with JSON-string arguments."""
    model = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=[MyModel])
    result = model.invoke("my name is Erick and I am 27 years old")

    assert isinstance(result, AIMessage)
    assert isinstance(result.content, str)
    assert result.content == ""

    fc = result.additional_kwargs.get("function_call")
    assert fc
    assert fc["name"] == "MyModel"

    raw_args = fc.get("arguments")
    assert raw_args
    parsed = json.loads(raw_args)
    assert parsed == {
        "name": "Erick",
        "age": 27.0,
    }

Loading…
Cancel
Save