[experimental][llms][OllamaFunctions] tool calling related fixes (#22339)

Fixes issues with tool calling to handle tool objects correctly. Added
support to handle ToolMessage correctly.
Added additional checks for error conditions.

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
pull/22627/head
Karim Lalani 4 weeks ago committed by GitHub
parent d04e899b56
commit 276be6cdd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -13,6 +13,7 @@ from typing import (
TypedDict,
TypeVar,
Union,
cast,
overload,
)
@ -22,7 +23,14 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage, ToolCall
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser
@ -74,18 +82,32 @@ def _is_pydantic_class(obj: Any) -> bool:
)
def _is_pydantic_object(obj: Any) -> bool:
return isinstance(obj, BaseModel)
def convert_to_ollama_tool(tool: Any) -> Dict:
"""Convert a tool to an Ollama tool."""
description = None
if _is_pydantic_class(tool):
schema = tool.construct().schema()
definition = {"name": schema["title"], "properties": schema["properties"]}
if "required" in schema:
definition["required"] = schema["required"]
name = schema["title"]
elif _is_pydantic_object(tool):
schema = tool.get_input_schema().schema()
name = tool.get_name()
description = tool.description
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
return tool.copy()
else:
raise ValueError(
f"""Cannot convert {tool} to an Ollama tool.
{tool} needs to be a Pydantic class, model, or a dict."""
)
definition = {"name": name, "parameters": schema}
if description:
definition["description"] = description
return definition
raise ValueError(
f"Cannot convert {tool} to an Ollama tool. {tool} needs to be a Pydantic model."
)
return definition
class _AllReturnType(TypedDict):
@ -280,6 +302,59 @@ class OllamaFunctions(ChatOllama):
else:
return llm | parser_chain
def _convert_messages_to_ollama_messages(
self, messages: List[BaseMessage]
) -> List[Dict[str, Union[str, List[str]]]]:
ollama_messages: List = []
for message in messages:
role = ""
if isinstance(message, HumanMessage):
role = "user"
elif isinstance(message, AIMessage) or isinstance(message, ToolMessage):
role = "assistant"
elif isinstance(message, SystemMessage):
role = "system"
else:
raise ValueError("Received unsupported message type for Ollama.")
content = ""
images = []
if isinstance(message.content, str):
content = message.content
else:
for content_part in cast(List[Dict], message.content):
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "image_url":
if isinstance(content_part.get("image_url"), str):
image_url_components = content_part["image_url"].split(",")
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
images.append(image_url_components[1])
else:
images.append(image_url_components[0])
else:
raise ValueError(
"Only string image_url content parts are supported."
)
else:
raise ValueError(
"Unsupported message content type. "
"Must either have type 'text' or type 'image_url' "
"with a string 'image_url' field."
)
ollama_messages.append(
{
"role": role,
"content": content,
"images": images,
}
)
return ollama_messages
def _generate(
self,
messages: List[BaseMessage],
@ -300,9 +375,8 @@ class OllamaFunctions(ChatOllama):
"matching function in `functions`."
)
del kwargs["function_call"]
if _is_pydantic_class(functions[0]):
functions = [convert_to_ollama_tool(fn) for fn in functions]
functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
functions = [convert_to_ollama_tool(fn) for fn in functions]
functions.append(DEFAULT_RESPONSE_FUNCTION)
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
self.tool_system_prompt_template
)
@ -323,16 +397,16 @@ class OllamaFunctions(ChatOllama):
Please try again.
Response: {chat_generation_content}"""
)
called_tool_name = parsed_chat_result["tool"]
called_tool_name = (
parsed_chat_result["tool"] if "tool" in parsed_chat_result else None
)
called_tool = next(
(fn for fn in functions if fn["name"] == called_tool_name), None
)
if called_tool is None:
raise ValueError(
f"Failed to parse a function call from {self.model} output: "
f"{chat_generation_content}"
)
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
if (
called_tool is None
or called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]
):
if (
"tool_input" in parsed_chat_result
and "response" in parsed_chat_result["tool_input"]
@ -355,7 +429,11 @@ class OllamaFunctions(ChatOllama):
]
)
called_tool_arguments = parsed_chat_result["tool_input"]
called_tool_arguments = (
parsed_chat_result["tool_input"]
if "tool_input" in parsed_chat_result
else {}
)
response_message_with_functions = AIMessage(
content="",

@ -2,6 +2,8 @@
import unittest
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_community.tools.pubmed.tool import PubmedQueryRun
from langchain_core.messages import AIMessage
from langchain_core.pydantic_v1 import BaseModel, Field
@ -22,7 +24,7 @@ class TestOllamaFunctions(unittest.TestCase):
"""
def test_default_ollama_functions(self) -> None:
base_model = OllamaFunctions(model="llama3", format="json")
base_model = OllamaFunctions(model="phi3", format="json")
# bind functions
model = base_model.bind_tools(
@ -60,8 +62,22 @@ class TestOllamaFunctions(unittest.TestCase):
assert tool_call
self.assertEqual("get_current_weather", tool_call.get("name"))
def test_ollama_functions_tools(self) -> None:
base_model = OllamaFunctions(model="phi3", format="json")
model = base_model.bind_tools(
tools=[PubmedQueryRun(), DuckDuckGoSearchResults(max_results=2)]
)
res = model.invoke("What causes lung cancer?")
self.assertIsInstance(res, AIMessage)
res = AIMessage(**res.__dict__)
tool_calls = res.tool_calls
assert tool_calls
tool_call = tool_calls[0]
assert tool_call
self.assertEqual("pub_med", tool_call.get("name"))
def test_default_ollama_functions_default_response(self) -> None:
base_model = OllamaFunctions(model="llama3", format="json")
base_model = OllamaFunctions(model="phi3", format="json")
# bind functions
model = base_model.bind_tools(

Loading…
Cancel
Save