You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
EVAL/core/agents/chat_agent.py

127 lines
4.2 KiB
Python

from typing import Any, List, Optional, Sequence, Tuple
from langchain.agents.agent import Agent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.schema import BaseOutputParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain.schema import (
AgentAction,
AIMessage,
BaseLanguageModel,
BaseMessage,
HumanMessage,
)
from langchain.tools.base import BaseTool
from core.prompts.input import EVAL_TOOL_RESPONSE
class ConversationalChatAgent(Agent):
    """An agent designed to hold a conversation in addition to using tools.

    The prompt is assembled as a chat prompt (system message, chat history,
    human message, agent scratchpad). Raw LLM replies are decoded by
    ``output_parser`` into an ``action`` / ``action_input`` pair.
    """

    # Parser that supplies the format instructions embedded in the prompt
    # and decodes raw LLM output in ``_extract_tool_and_input``.
    output_parser: BaseOutputParser

    @property
    def _agent_type(self) -> str:
        """Agent type identifier; intentionally not provided by this class."""
        raise NotImplementedError

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought: "

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        system_message: str,
        human_message: str,
        output_parser: BaseOutputParser,
        input_variables: Optional[List[str]] = None,
    ) -> BasePromptTemplate:
        """Build the chat prompt template used by the agent.

        ``human_message`` goes through ``str.format`` twice and is then
        handed to ``HumanMessagePromptTemplate.from_template``:

        1. the first pass fills ``{format_instructions}``;
        2. the second pass fills ``{tool_names}`` and ``{tools}``;
        3. whatever placeholders remain are resolved by the prompt template.

        Placeholders intended for a later stage must therefore be
        brace-escaped in ``human_message`` once per earlier pass.

        Args:
            tools: Tools whose names and descriptions are listed in the prompt.
            system_message: Template text for the leading system message.
            human_message: Template text for the human message (see above).
            output_parser: Source of the format instructions.
            input_variables: Prompt input variables; defaults to
                ``["input", "chat_history", "agent_scratchpad"]``.

        Returns:
            A ``ChatPromptTemplate`` with the system message, the
            ``chat_history`` placeholder, the human message and the
            ``agent_scratchpad`` placeholder, in that order.
        """
        tool_strings = "\n".join(
            f"> {tool.name}: {tool.description}" for tool in tools
        )
        tool_names = ", ".join(tool.name for tool in tools)

        # First formatting pass: inject the parser's format instructions.
        format_instructions = human_message.format(
            format_instructions=output_parser.get_format_instructions()
        )
        # Second formatting pass: inject the tool listing.
        final_prompt = format_instructions.format(
            tool_names=tool_names, tools=tool_strings
        )

        if input_variables is None:
            input_variables = ["input", "chat_history", "agent_scratchpad"]

        messages = [
            SystemMessagePromptTemplate.from_template(system_message),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template(final_prompt),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    def _extract_tool_and_input(self, llm_output: str) -> Optional[Tuple[str, str]]:
        """Decode the LLM output into a ``(tool, tool_input)`` pair.

        Args:
            llm_output: Raw text produced by the LLM.

        Returns:
            The ``"action"`` and ``"action_input"`` entries of the parsed
            response.

        Raises:
            ValueError: If parsing fails or either key is missing. The
                underlying error is attached as ``__cause__``.
        """
        try:
            response = self.output_parser.parse(llm_output)
            return response["action"], response["action_input"]
        except Exception as exc:
            # Chain explicitly so the original parser/lookup failure stays
            # visible as the direct cause in tracebacks.
            raise ValueError(f"Could not parse LLM output: {llm_output}") from exc

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> List[BaseMessage]:
        """Construct the scratchpad that lets the agent continue its thought process.

        Each intermediate step contributes two messages: an ``AIMessage``
        holding the action's log, followed by a ``HumanMessage`` holding the
        observation rendered through ``EVAL_TOOL_RESPONSE``.
        """
        thoughts: List[BaseMessage] = []
        for action, observation in intermediate_steps:
            thoughts.append(AIMessage(content=action.log))
            human_message = HumanMessage(
                content=EVAL_TOOL_RESPONSE.format(observation=observation)
            )
            thoughts.append(human_message)
        return thoughts

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        system_message: str,
        human_message: str,
        output_parser: BaseOutputParser,
        callback_manager: Optional[BaseCallbackManager] = None,
        input_variables: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools.

        Args:
            llm: Language model driving the agent's ``LLMChain``.
            tools: Tools the agent may call; validated via
                ``cls._validate_tools`` before anything else is built.
            system_message: Forwarded to ``create_prompt``.
            human_message: Forwarded to ``create_prompt``.
            output_parser: Forwarded to ``create_prompt`` and stored on the
                agent.
            callback_manager: Optional callback manager for the ``LLMChain``.
            input_variables: Forwarded to ``create_prompt``.
            **kwargs: Extra keyword arguments passed to the agent constructor.

        Returns:
            A new instance of ``cls`` whose allowed tools are the names of
            ``tools``.
        """
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            system_message=system_message,
            human_message=human_message,
            input_variables=input_variables,
            output_parser=output_parser,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=output_parser,
            **kwargs,
        )