import re
from typing import Any, List, Optional, Sequence, Tuple

from pydantic import Field

from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.structured_chat.output_parser import (
    StructuredChatOutputParserWithRetries,
)
from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction
from langchain.tools import BaseTool

HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"


class StructuredChatAgent(Agent):
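    """Structured Chat Agent.

    A zero-shot ReAct-style chat agent whose prompt and output format allow
    tools that accept multiple, structured arguments.
    """
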
    output_parser: AgentOutputParser = Field(
        default_factory=StructuredChatOutputParserWithRetries
    )

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Observation: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the llm call with."""
        return "Thought:"

    def _construct_scratchpad(
        self, intermediate_steps: List[Tuple[AgentAction, str]]
    ) -> str:
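        """Construct the scratchpad that lets the agent continue its thought process."""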
        agent_scratchpad = super()._construct_scratchpad(intermediate_steps)
        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            return (
                f"This was your previous work "
                f"(but I haven't seen any of it! I only see what "
                f"you return as final answer):\n{agent_scratchpad}"
            )
        else:
            return agent_scratchpad

    @classmethod
    def _validate_tools(cls, tools: Sequence[BaseTool]) -> None:
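        # Intentionally a no-op: structured chat tools may take multiple
        # arguments, so the single-input check used by other agents is
        # presumably not applicable here.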
        pass

    @classmethod
    def _get_default_output_parser(
        cls, llm: Optional[BaseLanguageModel] = None, **kwargs: Any
    ) -> AgentOutputParser:
        return StructuredChatOutputParserWithRetries.from_llm(llm=llm)

    @property
    def _stop(self) -> List[str]:
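        # Stop generation as soon as the model starts to write an observation
        # itself; the real tool output is injected instead.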
return ["Observation:"]
|
|
|
|
    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
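        """Build a ChatPromptTemplate describing the tools and the expected output format."""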
        tool_strings = []
        for tool in tools:
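            # Escape literal braces in the tool's args schema so they are not
            # interpreted as template variables when the prompt is formatted.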
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[List[str]] = None,
        memory_prompts: Optional[List[BasePromptTemplate]] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        prompt = cls.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            human_message_template=human_message_template,
            format_instructions=format_instructions,
            input_variables=input_variables,
            memory_prompts=memory_prompts,
        )
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser or cls._get_default_output_parser(llm=llm)
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )

    @property
    def _agent_type(self) -> str:
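        # No agent-type identifier is defined for this agent; accessing this
        # property (e.g. when saving the agent) raises.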
        raise ValueError
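
# Usage sketch, assuming a chat model `llm` and a list of multi-input `tools`
# already exist; the agent is typically wrapped in an AgentExecutor:
#
#     from langchain.agents import AgentExecutor
#
#     agent = StructuredChatAgent.from_llm_and_tools(llm=llm, tools=tools)
#     executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools)
#     executor.run("look up today's weather and convert it to Celsius")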