forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
5.4 KiB
Python
156 lines
5.4 KiB
Python
"""An agent designed to hold a conversation in addition to using tools."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any, List, Optional, Sequence, Tuple
|
|
|
|
from langchain.agents.agent import Agent
|
|
from langchain.agents.conversational_chat.prompt import (
|
|
FORMAT_INSTRUCTIONS,
|
|
PREFIX,
|
|
SUFFIX,
|
|
TEMPLATE_TOOL_RESPONSE,
|
|
)
|
|
from langchain.callbacks.base import BaseCallbackManager
|
|
from langchain.chains import LLMChain
|
|
from langchain.output_parsers.base import BaseOutputParser
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
from langchain.prompts.chat import (
|
|
ChatPromptTemplate,
|
|
HumanMessagePromptTemplate,
|
|
MessagesPlaceholder,
|
|
SystemMessagePromptTemplate,
|
|
)
|
|
from langchain.schema import (
|
|
AgentAction,
|
|
AIMessage,
|
|
BaseLanguageModel,
|
|
BaseMessage,
|
|
HumanMessage,
|
|
)
|
|
from langchain.tools.base import BaseTool
|
|
|
|
|
|
class AgentOutputParser(BaseOutputParser):
    """Parse a chat model reply into an ``action`` / ``action_input`` pair."""

    def get_format_instructions(self) -> str:
        """Return the format instructions string used when building the prompt."""
        return FORMAT_INSTRUCTIONS

    def parse(self, text: str) -> Any:
        """Extract the JSON payload from ``text`` and return its action fields.

        The reply may be bare JSON or JSON wrapped in a Markdown code fence
        (```` ``` ```` or ```` ```json ````), optionally preceded by text before
        a ```` ```json ```` fence and followed by text after the JSON value.

        Args:
            text: Raw text produced by the language model.

        Returns:
            A dict with exactly the keys ``"action"`` and ``"action_input"``,
            copied from the decoded JSON object.

        Raises:
            ValueError: If no JSON value can be decoded at the start of the
                cleaned text (``json.JSONDecodeError`` is a ``ValueError``).
            KeyError: If the decoded object lacks ``"action"`` or
                ``"action_input"``.
            TypeError: If the decoded JSON value is not subscriptable by
                string keys (for example a number or a list).
        """
        cleaned_output = text.strip()

        if "```json" in cleaned_output:
            # maxsplit=1 keeps everything after the FIRST opening fence; without
            # it a reply containing several "```json" fences fails to unpack.
            _, cleaned_output = cleaned_output.split("```json", 1)
        if cleaned_output.startswith("```"):
            cleaned_output = cleaned_output[len("```") :]
        if cleaned_output.endswith("```"):
            cleaned_output = cleaned_output[: -len("```")]

        # raw_decode does not skip leading whitespace, so strip first.
        cleaned_output = cleaned_output.strip()

        # Decode only the first JSON value and ignore whatever follows it
        # (a closing fence plus trailing commentary, or further fenced blocks).
        # Backticks inside JSON string values are handled by the JSON parser
        # itself, so embedded code fences in "action_input" stay intact.
        response, _ = json.JSONDecoder().raw_decode(cleaned_output)
        return {"action": response["action"], "action_input": response["action_input"]}
|
|
|
|
|
|
class ConversationalChatAgent(Agent):
    """An agent designed to hold a conversation in addition to using tools."""

    # Parser that turns raw LLM output into an action / action-input mapping.
    output_parser: BaseOutputParser

    @property
    def _agent_type(self) -> str:
        """Agent type identifier; this class does not provide one."""
        raise NotImplementedError

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        system_message: str = PREFIX,
        human_message: str = SUFFIX,
        input_variables: Optional[List[str]] = None,
        output_parser: Optional[BaseOutputParser] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt template for this agent.

        Args:
            tools: Tools whose names and descriptions are listed in the prompt.
            system_message: Template text for the system message.
            human_message: Template text for the human message. It is
                formatted twice: first with ``format_instructions``, then with
                ``tool_names`` and ``tools``.
            input_variables: Input variables of the resulting prompt. Defaults
                to ``["input", "chat_history", "agent_scratchpad"]``.
            output_parser: Parser supplying the format instructions. Defaults
                to a new ``AgentOutputParser``.

        Returns:
            A ``ChatPromptTemplate`` made of the system message, the
            ``chat_history`` placeholder, the human message and the
            ``agent_scratchpad`` placeholder, in that order.
        """
        # NOTE(review): tool names/descriptions are inserted verbatim; literal
        # braces in them would presumably be read as template fields by
        # ``from_template`` below - confirm whether escaping is required.
        tool_strings = "\n".join(
            [f"> {tool.name}: {tool.description}" for tool in tools]
        )
        tool_names = ", ".join([tool.name for tool in tools])
        _output_parser = output_parser or AgentOutputParser()

        # First pass: splice the parser's format instructions into the
        # human message template.
        format_instructions = human_message.format(
            format_instructions=_output_parser.get_format_instructions()
        )
        # Second pass: fill the tool placeholders that the first pass exposed.
        final_prompt = format_instructions.format(
            tool_names=tool_names, tools=tool_strings
        )

        if input_variables is None:
            input_variables = ["input", "chat_history", "agent_scratchpad"]
        messages = [
            SystemMessagePromptTemplate.from_template(system_message),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template(final_prompt),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    def _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
        """Return the ``(action, action_input)`` pair parsed from ``llm_output``.

        Raises:
            ValueError: If the output parser fails for any reason or its
                result lacks the expected keys. The underlying exception is
                attached as ``__cause__``.
        """
        try:
            response = self.output_parser.parse(llm_output)
            return response["action"], response["action_input"]
        except Exception as err:
            # Chain explicitly so the parser's own error stays visible as the
            # direct cause of the ValueError.
            raise ValueError(f"Could not parse LLM output: {llm_output}") from err

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> List[BaseMessage]:
        """Construct the scratchpad that lets the agent continue its thought process.

        Each step contributes two messages: an ``AIMessage`` holding the
        action's log, then a ``HumanMessage`` holding the observation rendered
        through ``TEMPLATE_TOOL_RESPONSE``.
        """
        thoughts: List[BaseMessage] = []
        for action, observation in intermediate_steps:
            thoughts.append(AIMessage(content=action.log))
            human_message = HumanMessage(
                content=TEMPLATE_TOOL_RESPONSE.format(observation=observation)
            )
            thoughts.append(human_message)
        return thoughts

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        system_message: str = PREFIX,
        human_message: str = SUFFIX,
        input_variables: Optional[List[str]] = None,
        output_parser: Optional[BaseOutputParser] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Args:
            llm: Language model that drives the agent's ``LLMChain``.
            tools: Tools the agent may use; their names become
                ``allowed_tools``.
            callback_manager: Optional callback manager passed to the
                ``LLMChain``.
            system_message: Forwarded to ``create_prompt``.
            human_message: Forwarded to ``create_prompt``.
            input_variables: Forwarded to ``create_prompt``.
            output_parser: Parser used both for the prompt's format
                instructions and by the agent itself. Defaults to a new
                ``AgentOutputParser``.
            **kwargs: Extra keyword arguments passed to the class constructor.

        Returns:
            A new instance of ``cls``.
        """
        cls._validate_tools(tools)
        # Resolve the parser once so the prompt and the agent share the same
        # instance.
        _output_parser = output_parser or AgentOutputParser()
        prompt = cls.create_prompt(
            tools,
            system_message=system_message,
            human_message=human_message,
            input_variables=input_variables,
            output_parser=_output_parser,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
|