From 276be6cdd4f6f220de0862e7246b32df4612dd5a Mon Sep 17 00:00:00 2001
From: Karim Lalani
Date: Wed, 12 Jun 2024 15:34:43 -0500
Subject: [PATCH] [experimental][llms][OllamaFunctions] tool calling related
 fixes (#22339)

Fixes tool calling so that tool objects (Pydantic classes, tool
instances, and pre-formatted dicts) are converted correctly. Adds
support for ToolMessage and additional checks for error conditions.

---------

Co-authored-by: ccurme
---
 .../llms/ollama_functions.py      | 116 +++++++++++++++---
 .../llms/test_ollama_functions.py |  20 ++-
 2 files changed, 115 insertions(+), 21 deletions(-)

diff --git a/libs/experimental/langchain_experimental/llms/ollama_functions.py b/libs/experimental/langchain_experimental/llms/ollama_functions.py
index d9ba49f7b1..0329ddeecd 100644
--- a/libs/experimental/langchain_experimental/llms/ollama_functions.py
+++ b/libs/experimental/langchain_experimental/llms/ollama_functions.py
@@ -13,6 +13,7 @@ from typing import (
     TypedDict,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 
@@ -22,7 +23,14 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import AIMessage, BaseMessage, ToolCall
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolCall,
+    ToolMessage,
+)
 from langchain_core.output_parsers.base import OutputParserLike
 from langchain_core.output_parsers.json import JsonOutputParser
 from langchain_core.output_parsers.pydantic import PydanticOutputParser
@@ -74,18 +82,32 @@ def _is_pydantic_class(obj: Any) -> bool:
     )
 
 
+def _is_pydantic_object(obj: Any) -> bool:
+    return isinstance(obj, BaseModel)
+
+
 def convert_to_ollama_tool(tool: Any) -> Dict:
     """Convert a tool to an Ollama tool."""
+    description = None
     if _is_pydantic_class(tool):
         schema = tool.construct().schema()
-        definition = {"name": schema["title"], "properties": schema["properties"]}
-        if "required" in schema:
-            definition["required"] = schema["required"]
+        name = schema["title"]
+    elif _is_pydantic_object(tool):
+        schema = tool.get_input_schema().schema()
+        name = tool.get_name()
+        description = tool.description
+    elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
+        return tool.copy()
+    else:
+        raise ValueError(
+            f"""Cannot convert {tool} to an Ollama tool.
+            {tool} needs to be a Pydantic class, model, or a dict."""
+        )
+    definition = {"name": name, "parameters": schema}
+    if description:
+        definition["description"] = description
 
-        return definition
-    raise ValueError(
-        f"Cannot convert {tool} to an Ollama tool. {tool} needs to be a Pydantic model."
-    )
+    return definition
 
 
 class _AllReturnType(TypedDict):
@@ -280,6 +302,59 @@ class OllamaFunctions(ChatOllama):
         else:
             return llm | parser_chain
 
+    def _convert_messages_to_ollama_messages(
+        self, messages: List[BaseMessage]
+    ) -> List[Dict[str, Union[str, List[str]]]]:
+        ollama_messages: List = []
+        for message in messages:
+            role = ""
+            if isinstance(message, HumanMessage):
+                role = "user"
+            elif isinstance(message, AIMessage) or isinstance(message, ToolMessage):
+                role = "assistant"
+            elif isinstance(message, SystemMessage):
+                role = "system"
+            else:
+                raise ValueError("Received unsupported message type for Ollama.")
+
+            content = ""
+            images = []
+            if isinstance(message.content, str):
+                content = message.content
+            else:
+                for content_part in cast(List[Dict], message.content):
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "image_url":
+                        if isinstance(content_part.get("image_url"), str):
+                            image_url_components = content_part["image_url"].split(",")
+                            # Support data:image/jpeg;base64,<image> format
+                            # and base64 strings
+                            if len(image_url_components) > 1:
+                                images.append(image_url_components[1])
+                            else:
+                                images.append(image_url_components[0])
+                        else:
+                            raise ValueError(
+                                "Only string image_url content parts are supported."
+                            )
+                    else:
+                        raise ValueError(
+                            "Unsupported message content type. "
+                            "Must either have type 'text' or type 'image_url' "
+                            "with a string 'image_url' field."
+                        )
+
+            ollama_messages.append(
+                {
+                    "role": role,
+                    "content": content,
+                    "images": images,
+                }
+            )
+
+        return ollama_messages
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -300,9 +375,8 @@ class OllamaFunctions(ChatOllama):
                     "matching function in `functions`."
                 )
             del kwargs["function_call"]
-        if _is_pydantic_class(functions[0]):
-            functions = [convert_to_ollama_tool(fn) for fn in functions]
-        functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
+        functions = [convert_to_ollama_tool(fn) for fn in functions]
+        functions.append(DEFAULT_RESPONSE_FUNCTION)
         system_message_prompt_template = SystemMessagePromptTemplate.from_template(
             self.tool_system_prompt_template
         )
@@ -323,16 +397,16 @@ class OllamaFunctions(ChatOllama):
                 Please try again.
                Response: {chat_generation_content}"""
            )
-        called_tool_name = parsed_chat_result["tool"]
+        called_tool_name = (
+            parsed_chat_result["tool"] if "tool" in parsed_chat_result else None
+        )
         called_tool = next(
             (fn for fn in functions if fn["name"] == called_tool_name), None
         )
-        if called_tool is None:
-            raise ValueError(
-                f"Failed to parse a function call from {self.model} output: "
-                f"{chat_generation_content}"
-            )
-        if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
+        if (
+            called_tool is None
+            or called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]
+        ):
             if (
                 "tool_input" in parsed_chat_result
                 and "response" in parsed_chat_result["tool_input"]
@@ -355,7 +429,11 @@ class OllamaFunctions(ChatOllama):
                 ]
             )
 
-        called_tool_arguments = parsed_chat_result["tool_input"]
+        called_tool_arguments = (
+            parsed_chat_result["tool_input"]
+            if "tool_input" in parsed_chat_result
+            else {}
+        )
 
         response_message_with_functions = AIMessage(
             content="",
diff --git a/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py b/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py
index 6c413a361d..fd7a065135 100644
--- a/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py
+++ b/libs/experimental/tests/integration_tests/llms/test_ollama_functions.py
@@ -2,6 +2,8 @@
 
 import unittest
 
+from langchain_community.tools import DuckDuckGoSearchResults
+from langchain_community.tools.pubmed.tool import PubmedQueryRun
 from langchain_core.messages import AIMessage
 from langchain_core.pydantic_v1 import BaseModel, Field
 
@@ -22,7 +24,7 @@ class TestOllamaFunctions(unittest.TestCase):
     """
 
     def test_default_ollama_functions(self) -> None:
-        base_model = OllamaFunctions(model="llama3", format="json")
+        base_model = OllamaFunctions(model="phi3", format="json")
 
         # bind functions
         model = base_model.bind_tools(
@@ -60,8 +62,22 @@ class TestOllamaFunctions(unittest.TestCase):
         assert tool_call
         self.assertEqual("get_current_weather", tool_call.get("name"))
 
+    def test_ollama_functions_tools(self) -> None:
+        base_model = OllamaFunctions(model="phi3", format="json")
+        model = base_model.bind_tools(
+            tools=[PubmedQueryRun(), DuckDuckGoSearchResults(max_results=2)]
+        )
+        res = model.invoke("What causes lung cancer?")
+        self.assertIsInstance(res, AIMessage)
+        res = AIMessage(**res.__dict__)
+        tool_calls = res.tool_calls
+        assert tool_calls
+        tool_call = tool_calls[0]
+        assert tool_call
+        self.assertEqual("pub_med", tool_call.get("name"))
+
     def test_default_ollama_functions_default_response(self) -> None:
-        base_model = OllamaFunctions(model="llama3", format="json")
+        base_model = OllamaFunctions(model="phi3", format="json")
 
         # bind functions
         model = base_model.bind_tools(
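Usage sketch: with this patch, `convert_to_ollama_tool` accepts a Pydantic class, a tool instance, or a pre-formatted dict with `name` and `parameters` keys. A minimal illustration of each shape, assuming the patch is installed; the `GetWeather` class and the `noop` dict below are hypothetical examples, not from the test suite:

```python
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_experimental.llms.ollama_functions import convert_to_ollama_tool


class GetWeather(BaseModel):
    """Get the current weather for a city."""

    city: str = Field(..., description="City name")


# 1. Pydantic class -> {"name": "GetWeather", "parameters": <JSON schema>}
print(convert_to_ollama_tool(GetWeather))

# 2. Tool instance -> name and description come from the tool itself,
#    parameters from its input schema, e.g.:
#    convert_to_ollama_tool(PubmedQueryRun())

# 3. Pre-formatted dict with "name" and "parameters" is returned as a copy.
print(convert_to_ollama_tool({"name": "noop", "parameters": {}}))

# Anything else raises a ValueError.
```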
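The new `_convert_messages_to_ollama_messages` helper flattens chat messages into Ollama's `role`/`content`/`images` dicts, splitting base64 image parts out of multimodal content. A sketch that calls the private helper directly just to show the output shape; the base64 payload is a placeholder:

```python
from langchain_core.messages import HumanMessage

from langchain_experimental.llms.ollama_functions import OllamaFunctions

llm = OllamaFunctions(model="phi3", format="json")  # no server call at construction
msg = HumanMessage(
    content=[
        {"type": "text", "text": "What is in this image?"},
        # Both "data:image/jpeg;base64,<b64>" and bare base64 strings work;
        # the data-URL prefix is stripped before the image is forwarded.
        {"type": "image_url", "image_url": "data:image/jpeg;base64,iVBORw0KGgo="},
    ]
)
print(llm._convert_messages_to_ollama_messages([msg]))
# -> [{'role': 'user', 'content': '\nWhat is in this image?',
#      'images': ['iVBORw0KGgo=']}]
```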
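End to end, mirroring the integration tests above: assuming a local Ollama server with the `phi3` model pulled (`ollama pull phi3`), a bound tool call now surfaces on `AIMessage.tool_calls`:

```python
from langchain_community.tools.pubmed.tool import PubmedQueryRun

from langchain_experimental.llms.ollama_functions import OllamaFunctions

base_model = OllamaFunctions(model="phi3", format="json")
model = base_model.bind_tools(tools=[PubmedQueryRun()])

res = model.invoke("What causes lung cancer?")
for tool_call in res.tool_calls:
    # e.g. pub_med {'query': 'causes of lung cancer'}
    print(tool_call["name"], tool_call["args"])
```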