You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/experimental/langchain_experimental/llms/ollama_functions.py

142 lines
4.8 KiB
Python

import json
from typing import Any, Dict, List, Optional
from langchain.chat_models.ollama import ChatOllama
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_experimental.pydantic_v1 import root_validator
# Default system prompt prepended to the conversation by OllamaFunctions.
# ``{tools}`` is the only template variable; it is filled with the
# JSON-serialized function specs (json.dumps(..., indent=2)). The doubled
# braces are format-string escapes so the rendered prompt shows a literal
# ``{ ... }`` schema for the model to imitate.
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
{tools}
You must always select one of the above tools and respond with only a JSON object matching the following schema:
{{
"tool": <name of the selected tool>,
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
}}
""" # noqa: E501
# Fallback function offered to the model when the caller supplies no
# functions. Choosing it lets the model answer in plain prose: its single
# required string argument, ``response``, becomes the message content.
DEFAULT_RESPONSE_FUNCTION = dict(
    name="__conversational_response",
    description=(
        "Respond conversationally if no other tools should be called "
        "for a given query."
    ),
    parameters=dict(
        type="object",
        properties=dict(
            response=dict(
                type="string",
                description="Conversational response to the user.",
            ),
        ),
        required=["response"],
    ),
)
class OllamaFunctions(BaseChatModel):
    """Function-calling chat model layered on top of ``ChatOllama``.

    Accepts the OpenAI-style ``functions`` / ``function_call`` keyword
    arguments. The function specs are rendered into a system prompt that asks
    the model to reply with a JSON object of the form
    ``{"tool": <name>, "tool_input": <arguments>}``. That reply is parsed and
    returned as an ``AIMessage`` whose ``additional_kwargs["function_call"]``
    holds the chosen name and JSON-encoded arguments. When the model picks the
    built-in conversational fallback, its ``response`` text is returned as
    ordinary message content instead.
    """

    # Wrapped Ollama chat model. If not supplied, ``validate_environment``
    # builds one from the remaining constructor arguments with
    # ``format="json"``.
    llm: ChatOllama
    # System-prompt template rendered with a single ``tools`` variable (the
    # JSON-serialized function specs). Defaults to DEFAULT_SYSTEM_TEMPLATE.
    tool_system_prompt_template: str

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Populate ``llm`` and ``tool_system_prompt_template`` when missing.

        Args:
            values: Raw constructor arguments.

        Returns:
            ``values`` with both fields set.
        """
        if not values.get("llm"):
            # Forward the other constructor arguments (e.g. ``model``) to
            # ChatOllama, minus this class's own fields, and force JSON mode.
            # Building one merged dict -- instead of passing ``format=`` next
            # to ``**values`` -- avoids a "multiple values for keyword
            # argument" TypeError when the caller also supplied ``format``.
            llm_kwargs = {
                key: value
                for key, value in values.items()
                if key not in ("llm", "tool_system_prompt_template")
            }
            llm_kwargs["format"] = "json"
            values["llm"] = ChatOllama(**llm_kwargs)
        values["tool_system_prompt_template"] = (
            values.get("tool_system_prompt_template") or DEFAULT_SYSTEM_TEMPLATE
        )
        return values

    @property
    def model(self) -> BaseChatModel:
        """Alias for ``llm``, kept for backwards compatibility."""
        return self.llm

    @staticmethod
    def _select_functions(kwargs: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Pop ``functions``/``function_call`` from ``kwargs``; return the specs to offer.

        ``kwargs`` is mutated so the remaining entries can be forwarded to the
        wrapped model unchanged.

        Raises:
            ValueError: if ``function_call`` names no function in ``functions``.
        """
        # Copy so the caller's list is never mutated when the fallback
        # function is appended below; ``None`` is treated like "no functions".
        functions = list(kwargs.pop("functions", None) or [])
        if "function_call" in kwargs:
            # A forced call restricts the offer to the named function only.
            forced_name = kwargs["function_call"]["name"]
            functions = [fn for fn in functions if fn["name"] == forced_name]
            if not functions:
                raise ValueError(
                    'If "function_call" is specified, you must also pass a matching '
                    'function in "functions".'
                )
            del kwargs["function_call"]
        elif not functions:
            # Nothing to call: offer the conversational fallback so the model
            # can still answer in prose.
            functions.append(DEFAULT_RESPONSE_FUNCTION)
        return functions

    def _malformed_output_error(self, content: str) -> ValueError:
        """Build the error raised when the reply is not a usable tool call."""
        return ValueError(
            f"Failed to parse a function call from {self.llm.model} "
            f"output: {content}"
        )

    def _parse_tool_call(self, content: str) -> Dict[str, Any]:
        """Decode the model's reply into a dict that has at least a ``"tool"`` key.

        Raises:
            ValueError: if ``content`` is not valid JSON, or is JSON but not an
                object with a ``"tool"`` key.
        """
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError as err:
            raise ValueError(
                f'"{self.llm.model}" did not respond with valid JSON. '
                "Please try again."
            ) from err
        if not isinstance(parsed, dict) or "tool" not in parsed:
            raise self._malformed_output_error(content)
        return parsed

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run one function-calling round trip against the wrapped model.

        Args:
            messages: Conversation to send; a tool-selection system message is
                prepended to it.
            stop: Optional stop sequences forwarded to the wrapped model.
            run_manager: Forwarded to the wrapped model as ``callbacks``.
            **kwargs: ``functions`` (list of function specs) and
                ``function_call`` (``{"name": ...}`` forcing one function) are
                consumed here; all other entries are forwarded to the wrapped
                model.

        Returns:
            A single-generation ``ChatResult``. The message carries either the
            conversational fallback's ``response`` text as content, or empty
            content plus ``additional_kwargs["function_call"]`` with the
            chosen function name and JSON-encoded arguments ("" when the
            arguments are empty).

        Raises:
            ValueError: if ``function_call`` matches no supplied function, or
                the model reply is non-string, not valid JSON, or not a valid
                call to one of the offered functions.
        """
        functions = self._select_functions(kwargs)
        system_message = SystemMessagePromptTemplate.from_template(
            self.tool_system_prompt_template
        ).format(tools=json.dumps(functions, indent=2))
        response_message = self.llm.predict_messages(
            [system_message] + messages, stop=stop, callbacks=run_manager, **kwargs
        )
        content = response_message.content
        if not isinstance(content, str):
            raise ValueError("OllamaFunctions does not support non-string output.")

        tool_call = self._parse_tool_call(content)
        called_tool_name = tool_call["tool"]
        # Tolerate a missing ``tool_input`` (e.g. for parameterless functions).
        called_tool_arguments = tool_call.get("tool_input")
        called_tool = next(
            (fn for fn in functions if fn["name"] == called_tool_name), None
        )
        if called_tool is None:
            raise self._malformed_output_error(content)

        if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
            # Conversational fallback: unwrap the prose answer.
            if (
                not isinstance(called_tool_arguments, dict)
                or "response" not in called_tool_arguments
            ):
                raise self._malformed_output_error(content)
            return ChatResult(
                generations=[
                    ChatGeneration(
                        message=AIMessage(
                            content=called_tool_arguments["response"],
                        )
                    )
                ]
            )

        function_call = {
            "name": called_tool_name,
            "arguments": (
                json.dumps(called_tool_arguments) if called_tool_arguments else ""
            ),
        }
        response_message_with_functions = AIMessage(
            content="",
            additional_kwargs={"function_call": function_call},
        )
        return ChatResult(
            generations=[ChatGeneration(message=response_message_with_functions)]
        )

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "ollama_functions"