"""Chain that does self ask with search.""" from typing import Any, Optional, Sequence, Tuple, Union from langchain.agents.agent import Agent, AgentExecutor from langchain.agents.self_ask_with_search.prompt import PROMPT from langchain.agents.tools import Tool from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate from langchain.tools.base import BaseTool from langchain.utilities.google_serper import GoogleSerperAPIWrapper from langchain.utilities.serpapi import SerpAPIWrapper class SelfAskWithSearchAgent(Agent): """Agent for the self-ask-with-search paper.""" @property def _agent_type(self) -> str: """Return Identifier of agent type.""" return "self-ask-with-search" @classmethod def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: """Prompt does not depend on tools.""" return PROMPT @classmethod def _validate_tools(cls, tools: Sequence[BaseTool]) -> None: if len(tools) != 1: raise ValueError(f"Exactly one tool must be specified, but got {tools}") tool_names = {tool.name for tool in tools} if tool_names != {"Intermediate Answer"}: raise ValueError( f"Tool name should be Intermediate Answer, got {tool_names}" ) def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]: followup = "Follow up:" last_line = text.split("\n")[-1] if followup not in last_line: finish_string = "So the final answer is: " if finish_string not in last_line: return None return "Final Answer", last_line[len(finish_string) :] after_colon = text.split(":")[-1] if " " == after_colon[0]: after_colon = after_colon[1:] return "Intermediate Answer", after_colon def _fix_text(self, text: str) -> str: return f"{text}\nSo the final answer is:" @property def observation_prefix(self) -> str: """Prefix to append the observation with.""" return "Intermediate answer: " @property def llm_prefix(self) -> str: """Prefix to append the LLM call with.""" return "" @property def starter_string(self) -> str: """Put this string after user input but before first LLM call.""" return "Are follow up questions needed here:" class SelfAskWithSearchChain(AgentExecutor): """Chain that does self ask with search. Example: .. code-block:: python from langchain import SelfAskWithSearchChain, OpenAI, GoogleSerperAPIWrapper search_chain = GoogleSerperAPIWrapper() self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain) """ def __init__( self, llm: BaseLLM, search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper], **kwargs: Any, ): """Initialize with just an LLM and a search chain.""" search_tool = Tool( name="Intermediate Answer", func=search_chain.run, description="Search" ) agent = SelfAskWithSearchAgent.from_llm_and_tools(llm, [search_tool]) super().__init__(agent=agent, tools=[search_tool], **kwargs)