|
|
@ -13,6 +13,7 @@ from typing import (
|
|
|
|
TypedDict,
|
|
|
|
TypedDict,
|
|
|
|
TypeVar,
|
|
|
|
TypeVar,
|
|
|
|
Union,
|
|
|
|
Union,
|
|
|
|
|
|
|
|
cast,
|
|
|
|
overload,
|
|
|
|
overload,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
@ -22,7 +23,14 @@ from langchain_core.callbacks import (
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from langchain_core.language_models import LanguageModelInput
|
|
|
|
from langchain_core.language_models import LanguageModelInput
|
|
|
|
from langchain_core.messages import AIMessage, BaseMessage, ToolCall
|
|
|
|
from langchain_core.messages import (
|
|
|
|
|
|
|
|
AIMessage,
|
|
|
|
|
|
|
|
BaseMessage,
|
|
|
|
|
|
|
|
HumanMessage,
|
|
|
|
|
|
|
|
SystemMessage,
|
|
|
|
|
|
|
|
ToolCall,
|
|
|
|
|
|
|
|
ToolMessage,
|
|
|
|
|
|
|
|
)
|
|
|
|
from langchain_core.output_parsers.base import OutputParserLike
|
|
|
|
from langchain_core.output_parsers.base import OutputParserLike
|
|
|
|
from langchain_core.output_parsers.json import JsonOutputParser
|
|
|
|
from langchain_core.output_parsers.json import JsonOutputParser
|
|
|
|
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
|
|
|
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
|
|
@ -74,18 +82,32 @@ def _is_pydantic_class(obj: Any) -> bool:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_pydantic_object(obj: Any) -> bool:
|
|
|
|
|
|
|
|
return isinstance(obj, BaseModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_to_ollama_tool(tool: Any) -> Dict:
|
|
|
|
def convert_to_ollama_tool(tool: Any) -> Dict:
|
|
|
|
"""Convert a tool to an Ollama tool."""
|
|
|
|
"""Convert a tool to an Ollama tool."""
|
|
|
|
|
|
|
|
description = None
|
|
|
|
if _is_pydantic_class(tool):
|
|
|
|
if _is_pydantic_class(tool):
|
|
|
|
schema = tool.construct().schema()
|
|
|
|
schema = tool.construct().schema()
|
|
|
|
definition = {"name": schema["title"], "properties": schema["properties"]}
|
|
|
|
name = schema["title"]
|
|
|
|
if "required" in schema:
|
|
|
|
elif _is_pydantic_object(tool):
|
|
|
|
definition["required"] = schema["required"]
|
|
|
|
schema = tool.get_input_schema().schema()
|
|
|
|
|
|
|
|
name = tool.get_name()
|
|
|
|
|
|
|
|
description = tool.description
|
|
|
|
|
|
|
|
elif isinstance(tool, dict) and "name" in tool and "parameters" in tool:
|
|
|
|
|
|
|
|
return tool.copy()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
f"""Cannot convert {tool} to an Ollama tool.
|
|
|
|
|
|
|
|
{tool} needs to be a Pydantic class, model, or a dict."""
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
definition = {"name": name, "parameters": schema}
|
|
|
|
|
|
|
|
if description:
|
|
|
|
|
|
|
|
definition["description"] = description
|
|
|
|
|
|
|
|
|
|
|
|
return definition
|
|
|
|
return definition
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
f"Cannot convert {tool} to an Ollama tool. {tool} needs to be a Pydantic model."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AllReturnType(TypedDict):
|
|
|
|
class _AllReturnType(TypedDict):
|
|
|
@ -280,6 +302,59 @@ class OllamaFunctions(ChatOllama):
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
return llm | parser_chain
|
|
|
|
return llm | parser_chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_messages_to_ollama_messages(
|
|
|
|
|
|
|
|
self, messages: List[BaseMessage]
|
|
|
|
|
|
|
|
) -> List[Dict[str, Union[str, List[str]]]]:
|
|
|
|
|
|
|
|
ollama_messages: List = []
|
|
|
|
|
|
|
|
for message in messages:
|
|
|
|
|
|
|
|
role = ""
|
|
|
|
|
|
|
|
if isinstance(message, HumanMessage):
|
|
|
|
|
|
|
|
role = "user"
|
|
|
|
|
|
|
|
elif isinstance(message, AIMessage) or isinstance(message, ToolMessage):
|
|
|
|
|
|
|
|
role = "assistant"
|
|
|
|
|
|
|
|
elif isinstance(message, SystemMessage):
|
|
|
|
|
|
|
|
role = "system"
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError("Received unsupported message type for Ollama.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
content = ""
|
|
|
|
|
|
|
|
images = []
|
|
|
|
|
|
|
|
if isinstance(message.content, str):
|
|
|
|
|
|
|
|
content = message.content
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
for content_part in cast(List[Dict], message.content):
|
|
|
|
|
|
|
|
if content_part.get("type") == "text":
|
|
|
|
|
|
|
|
content += f"\n{content_part['text']}"
|
|
|
|
|
|
|
|
elif content_part.get("type") == "image_url":
|
|
|
|
|
|
|
|
if isinstance(content_part.get("image_url"), str):
|
|
|
|
|
|
|
|
image_url_components = content_part["image_url"].split(",")
|
|
|
|
|
|
|
|
# Support data:image/jpeg;base64,<image> format
|
|
|
|
|
|
|
|
# and base64 strings
|
|
|
|
|
|
|
|
if len(image_url_components) > 1:
|
|
|
|
|
|
|
|
images.append(image_url_components[1])
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
images.append(image_url_components[0])
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Only string image_url content parts are supported."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Unsupported message content type. "
|
|
|
|
|
|
|
|
"Must either have type 'text' or type 'image_url' "
|
|
|
|
|
|
|
|
"with a string 'image_url' field."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ollama_messages.append(
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
"role": role,
|
|
|
|
|
|
|
|
"content": content,
|
|
|
|
|
|
|
|
"images": images,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return ollama_messages
|
|
|
|
|
|
|
|
|
|
|
|
def _generate(
|
|
|
|
def _generate(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
messages: List[BaseMessage],
|
|
|
@ -300,9 +375,8 @@ class OllamaFunctions(ChatOllama):
|
|
|
|
"matching function in `functions`."
|
|
|
|
"matching function in `functions`."
|
|
|
|
)
|
|
|
|
)
|
|
|
|
del kwargs["function_call"]
|
|
|
|
del kwargs["function_call"]
|
|
|
|
if _is_pydantic_class(functions[0]):
|
|
|
|
functions = [convert_to_ollama_tool(fn) for fn in functions]
|
|
|
|
functions = [convert_to_ollama_tool(fn) for fn in functions]
|
|
|
|
functions.append(DEFAULT_RESPONSE_FUNCTION)
|
|
|
|
functions.insert(0, DEFAULT_RESPONSE_FUNCTION)
|
|
|
|
|
|
|
|
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
|
|
|
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
|
|
|
self.tool_system_prompt_template
|
|
|
|
self.tool_system_prompt_template
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -323,16 +397,16 @@ class OllamaFunctions(ChatOllama):
|
|
|
|
Please try again.
|
|
|
|
Please try again.
|
|
|
|
Response: {chat_generation_content}"""
|
|
|
|
Response: {chat_generation_content}"""
|
|
|
|
)
|
|
|
|
)
|
|
|
|
called_tool_name = parsed_chat_result["tool"]
|
|
|
|
called_tool_name = (
|
|
|
|
|
|
|
|
parsed_chat_result["tool"] if "tool" in parsed_chat_result else None
|
|
|
|
|
|
|
|
)
|
|
|
|
called_tool = next(
|
|
|
|
called_tool = next(
|
|
|
|
(fn for fn in functions if fn["name"] == called_tool_name), None
|
|
|
|
(fn for fn in functions if fn["name"] == called_tool_name), None
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if called_tool is None:
|
|
|
|
if (
|
|
|
|
raise ValueError(
|
|
|
|
called_tool is None
|
|
|
|
f"Failed to parse a function call from {self.model} output: "
|
|
|
|
or called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]
|
|
|
|
f"{chat_generation_content}"
|
|
|
|
):
|
|
|
|
)
|
|
|
|
|
|
|
|
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
|
|
|
|
|
|
|
|
if (
|
|
|
|
if (
|
|
|
|
"tool_input" in parsed_chat_result
|
|
|
|
"tool_input" in parsed_chat_result
|
|
|
|
and "response" in parsed_chat_result["tool_input"]
|
|
|
|
and "response" in parsed_chat_result["tool_input"]
|
|
|
@ -355,7 +429,11 @@ class OllamaFunctions(ChatOllama):
|
|
|
|
]
|
|
|
|
]
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
called_tool_arguments = parsed_chat_result["tool_input"]
|
|
|
|
called_tool_arguments = (
|
|
|
|
|
|
|
|
parsed_chat_result["tool_input"]
|
|
|
|
|
|
|
|
if "tool_input" in parsed_chat_result
|
|
|
|
|
|
|
|
else {}
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
response_message_with_functions = AIMessage(
|
|
|
|
response_message_with_functions = AIMessage(
|
|
|
|
content="",
|
|
|
|
content="",
|
|
|
|