|
|
|
@ -17,7 +17,10 @@ from typing import (
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from langchain_community.chat_models.ollama import ChatOllama
|
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
|
|
|
from langchain_core.callbacks import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.language_models import LanguageModelInput
|
|
|
|
|
from langchain_core.messages import AIMessage, BaseMessage, ToolCall
|
|
|
|
|
from langchain_core.output_parsers.base import OutputParserLike
|
|
|
|
@ -369,6 +372,86 @@ class OllamaFunctions(ChatOllama):
|
|
|
|
|
generations=[ChatGeneration(message=response_message_with_functions)]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def _agenerate(
|
|
|
|
|
self,
|
|
|
|
|
messages: List[BaseMessage],
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
functions = kwargs.get("functions", [])
|
|
|
|
|
if "functions" in kwargs:
|
|
|
|
|
del kwargs["functions"]
|
|
|
|
|
if "function_call" in kwargs:
|
|
|
|
|
functions = [
|
|
|
|
|
fn for fn in functions if fn["name"] == kwargs["function_call"]["name"]
|
|
|
|
|
]
|
|
|
|
|
if not functions:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"If `function_call` is specified, you must also pass a "
|
|
|
|
|
"matching function in `functions`."
|
|
|
|
|
)
|
|
|
|
|
del kwargs["function_call"]
|
|
|
|
|
elif not functions:
|
|
|
|
|
functions.append(DEFAULT_RESPONSE_FUNCTION)
|
|
|
|
|
if _is_pydantic_class(functions[0]):
|
|
|
|
|
functions = [convert_to_ollama_tool(fn) for fn in functions]
|
|
|
|
|
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
|
|
|
|
self.tool_system_prompt_template
|
|
|
|
|
)
|
|
|
|
|
system_message = system_message_prompt_template.format(
|
|
|
|
|
tools=json.dumps(functions, indent=2)
|
|
|
|
|
)
|
|
|
|
|
response_message = await super()._agenerate(
|
|
|
|
|
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
|
|
|
|
|
)
|
|
|
|
|
chat_generation_content = response_message.generations[0].text
|
|
|
|
|
if not isinstance(chat_generation_content, str):
|
|
|
|
|
raise ValueError("OllamaFunctions does not support non-string output.")
|
|
|
|
|
try:
|
|
|
|
|
parsed_chat_result = json.loads(chat_generation_content)
|
|
|
|
|
except json.JSONDecodeError:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"""'{self.model}' did not respond with valid JSON.
|
|
|
|
|
Please try again.
|
|
|
|
|
Response: {chat_generation_content}"""
|
|
|
|
|
)
|
|
|
|
|
called_tool_name = parsed_chat_result["tool"]
|
|
|
|
|
called_tool_arguments = parsed_chat_result["tool_input"]
|
|
|
|
|
called_tool = next(
|
|
|
|
|
(fn for fn in functions if fn["name"] == called_tool_name), None
|
|
|
|
|
)
|
|
|
|
|
if called_tool is None:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Failed to parse a function call from {self.model} output: "
|
|
|
|
|
f"{chat_generation_content}"
|
|
|
|
|
)
|
|
|
|
|
if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
|
|
|
|
|
return ChatResult(
|
|
|
|
|
generations=[
|
|
|
|
|
ChatGeneration(
|
|
|
|
|
message=AIMessage(
|
|
|
|
|
content=called_tool_arguments["response"],
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
response_message_with_functions = AIMessage(
|
|
|
|
|
content="",
|
|
|
|
|
additional_kwargs={
|
|
|
|
|
"function_call": {
|
|
|
|
|
"name": called_tool_name,
|
|
|
|
|
"arguments": json.dumps(called_tool_arguments)
|
|
|
|
|
if called_tool_arguments
|
|
|
|
|
else "",
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
return ChatResult(
|
|
|
|
|
generations=[ChatGeneration(message=response_message_with_functions)]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|
|
return "ollama_functions"
|
|
|
|
|