mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
480626dc99
…tch]: import models from community

ran
```bash
git grep -l 'from langchain\.chat_models' | xargs -L 1 sed -i '' "s/from\ langchain\.chat_models/from\ langchain_community.chat_models/g"
git grep -l 'from langchain\.llms' | xargs -L 1 sed -i '' "s/from\ langchain\.llms/from\ langchain_community.llms/g"
git grep -l 'from langchain\.embeddings' | xargs -L 1 sed -i '' "s/from\ langchain\.embeddings/from\ langchain_community.embeddings/g"
git checkout master libs/langchain/tests/unit_tests/llms
git checkout master libs/langchain/tests/unit_tests/chat_models
git checkout master libs/langchain/tests/unit_tests/embeddings/test_imports.py
make format
cd libs/langchain; make format
cd ../experimental; make format
cd ../core; make format
```
142 lines
4.8 KiB
Python
142 lines
4.8 KiB
Python
import json
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain_community.chat_models.ollama import ChatOllama
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models import BaseChatModel
|
|
from langchain_core.messages import AIMessage, BaseMessage
|
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
|
from langchain_core.prompts import SystemMessagePromptTemplate
|
|
|
|
from langchain_experimental.pydantic_v1 import root_validator
|
|
|
|
# Default system prompt used when no ``tool_system_prompt_template`` is given.
# ``{tools}`` is the only placeholder; it is filled with the JSON dump of the
# available function definitions. The doubled braces are presumably escapes
# that render as literal ``{`` / ``}`` once the template is formatted -- verify
# against the prompt-template format in use.
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
|
|
|
|
{tools}
|
|
|
|
You must always select one of the above tools and respond with only a JSON object matching the following schema:
|
|
|
|
{{
|
|
"tool": <name of the selected tool>,
|
|
"tool_input": <parameters for the selected tool, matching the tool's JSON schema>
|
|
}}
|
|
""" # noqa: E501
|
|
|
|
|
|
# Fallback function definition offered to the model when the caller supplies
# no functions: it lets the model answer in plain conversation, with the text
# carried in the single required ``response`` string parameter.
DEFAULT_RESPONSE_FUNCTION = {
    "name": "__conversational_response",
    "description": (
        "Respond conversationally if no other tools "
        "should be called for a given query."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "response": {
                "type": "string",
                "description": "Conversational response to the user.",
            },
        },
        "required": ["response"],
    },
}
|
|
|
|
|
|
class OllamaFunctions(BaseChatModel):
    """Chat model wrapper that emulates function calling on top of ``ChatOllama``.

    The wrapped model is prompted, through a system message rendered from
    ``tool_system_prompt_template``, to reply with a JSON object of the form
    ``{"tool": <name>, "tool_input": <arguments>}``. That reply is converted
    into an ``AIMessage`` whose ``additional_kwargs["function_call"]`` holds
    the selected function name and its JSON-encoded arguments.
    """

    # Underlying chat model; created with ``format="json"`` when not supplied.
    llm: ChatOllama

    # System prompt template; it is formatted with a single ``tools`` variable.
    tool_system_prompt_template: str

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Fill in the default ``llm`` and system prompt template.

        Args:
            values: Raw constructor keyword arguments.

        Returns:
            The same mapping with ``llm`` and ``tool_system_prompt_template``
            populated.
        """
        if not values.get("llm"):
            # Fields owned by this wrapper are not ChatOllama settings, so they
            # are not forwarded. ``format`` is set on the dict rather than
            # passed as an extra keyword so that a caller-supplied ``format``
            # cannot trigger "got multiple values for keyword argument".
            llm_kwargs = {
                key: value
                for key, value in values.items()
                if key not in ("llm", "tool_system_prompt_template")
            }
            llm_kwargs["format"] = "json"
            values["llm"] = ChatOllama(**llm_kwargs)
        values["tool_system_prompt_template"] = (
            values.get("tool_system_prompt_template") or DEFAULT_SYSTEM_TEMPLATE
        )
        return values

    @property
    def model(self) -> BaseChatModel:
        """For backwards compatibility."""
        return self.llm

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Ask the wrapped model to pick a function and translate its answer.

        Args:
            messages: Conversation to send after the generated system message.
            stop: Optional stop sequences forwarded to the wrapped model.
            run_manager: Optional run manager forwarded as ``callbacks``.
            **kwargs: May contain ``functions`` (list of function definitions,
                each with a ``"name"`` key) and ``function_call`` (mapping with
                a ``"name"`` key that forces one of ``functions``). Both are
                consumed here; everything else is forwarded to the wrapped
                model.

        Returns:
            A ``ChatResult`` holding either a plain conversational message
            (when the default response function was selected) or a message
            whose ``additional_kwargs["function_call"]`` describes the call.

        Raises:
            ValueError: If ``function_call`` names no function in
                ``functions``, if the model output is not a string or not
                valid JSON, or if the JSON does not describe a call to one of
                the offered functions.
        """
        # Work on a copy so that appending the default function below never
        # mutates a list owned by the caller; ``None`` is treated as "no
        # functions".
        functions = list(kwargs.pop("functions", None) or [])
        function_call = kwargs.pop("function_call", None)
        if function_call is not None:
            forced_name = function_call["name"]
            functions = [fn for fn in functions if fn["name"] == forced_name]
            if not functions:
                raise ValueError(
                    'If "function_call" is specified, you must also pass a matching '
                    'function in "functions".'
                )
        elif not functions:
            functions.append(DEFAULT_RESPONSE_FUNCTION)

        system_message_prompt_template = SystemMessagePromptTemplate.from_template(
            self.tool_system_prompt_template
        )
        system_message = system_message_prompt_template.format(
            tools=json.dumps(functions, indent=2)
        )
        response_message = self.llm.predict_messages(
            [system_message] + messages, stop=stop, callbacks=run_manager, **kwargs
        )

        chat_generation_content = response_message.content
        if not isinstance(chat_generation_content, str):
            raise ValueError("OllamaFunctions does not support non-string output.")
        try:
            parsed_chat_result = json.loads(chat_generation_content)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f'"{self.llm.model}" did not respond with valid JSON. Please try again.'
            ) from exc

        # One message for every "valid JSON, wrong shape" failure, so callers
        # see a ValueError instead of a bare KeyError/TypeError.
        parse_failure_message = (
            f"Failed to parse a function call from {self.llm.model} "
            f"output: {chat_generation_content}"
        )
        if not isinstance(parsed_chat_result, dict) or "tool" not in parsed_chat_result:
            raise ValueError(parse_failure_message)

        called_tool_name = parsed_chat_result["tool"]
        # A missing "tool_input" is treated like empty arguments.
        called_tool_arguments = parsed_chat_result.get("tool_input")
        called_tool = next(
            (fn for fn in functions if fn["name"] == called_tool_name), None
        )
        if called_tool is None:
            raise ValueError(parse_failure_message)

        if called_tool["name"] == DEFAULT_RESPONSE_FUNCTION["name"]:
            # The conversational fallback carries its text in "response".
            if (
                not isinstance(called_tool_arguments, dict)
                or "response" not in called_tool_arguments
            ):
                raise ValueError(parse_failure_message)
            return ChatResult(
                generations=[
                    ChatGeneration(
                        message=AIMessage(
                            content=called_tool_arguments["response"],
                        )
                    )
                ]
            )

        response_message_with_functions = AIMessage(
            content="",
            additional_kwargs={
                "function_call": {
                    "name": called_tool_name,
                    "arguments": json.dumps(called_tool_arguments)
                    if called_tool_arguments
                    else "",
                },
            },
        )

        return ChatResult(
            generations=[ChatGeneration(message=response_message_with_functions)]
        )

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this chat model type."""
        return "ollama_functions"
|