2022-11-22 14:16:26 +00:00
|
|
|
"""Chain that does self ask with search."""
|
2022-12-19 02:51:23 +00:00
|
|
|
from typing import Any, List, Optional, Tuple
|
2022-11-22 14:16:26 +00:00
|
|
|
|
2022-12-19 02:51:23 +00:00
|
|
|
from langchain.agents.agent import Agent, AgentExecutor
|
2022-11-22 14:16:26 +00:00
|
|
|
from langchain.agents.self_ask_with_search.prompt import PROMPT
|
|
|
|
from langchain.agents.tools import Tool
|
2022-12-18 21:22:42 +00:00
|
|
|
from langchain.llms.base import BaseLLM
|
2022-11-22 14:16:26 +00:00
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
2022-11-30 04:07:44 +00:00
|
|
|
from langchain.serpapi import SerpAPIWrapper
|
2022-11-22 14:16:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
class SelfAskWithSearchAgent(Agent):
    """Agent for the self-ask-with-search paper."""

    @classmethod
    def create_prompt(cls, tools: List[Tool]) -> BasePromptTemplate:
        """Prompt does not depend on tools."""
        return PROMPT

    @classmethod
    def _validate_tools(cls, tools: List[Tool]) -> None:
        """Validate that exactly one tool named "Intermediate Answer" is given.

        Raises:
            ValueError: if the tool count is not one, or the tool is misnamed.
        """
        if len(tools) != 1:
            raise ValueError(f"Exactly one tool must be specified, but got {tools}")
        tool_names = {tool.name for tool in tools}
        if tool_names != {"Intermediate Answer"}:
            raise ValueError(
                f"Tool name should be Intermediate Answer, got {tool_names}"
            )

    def _extract_tool_and_input(self, text: str) -> Optional[Tuple[str, str]]:
        """Parse LLM output into a (tool_name, tool_input) pair.

        Returns ("Intermediate Answer", query) when the last line contains a
        follow-up question, ("Final Answer", answer) when it contains the
        finish marker, and None when neither marker is present.
        """
        followup = "Follow up:"
        last_line = text.split("\n")[-1]

        if followup not in last_line:
            finish_string = "So the final answer is: "
            if finish_string not in last_line:
                return None
            return "Final Answer", last_line[len(finish_string) :]

        after_colon = text.split(":")[-1]

        # Drop a single leading space after the colon. Using startswith
        # (instead of indexing after_colon[0]) avoids an IndexError when the
        # text ends with ":" and after_colon is empty.
        if after_colon.startswith(" "):
            after_colon = after_colon[1:]

        return "Intermediate Answer", after_colon

    def _fix_text(self, text: str) -> str:
        """Append the finish marker so a truncated output can be re-parsed."""
        return f"{text}\nSo the final answer is:"

    @property
    def observation_prefix(self) -> str:
        """Prefix to append the observation with."""
        return "Intermediate answer: "

    @property
    def llm_prefix(self) -> str:
        """Prefix to append the LLM call with."""
        return ""

    @property
    def starter_string(self) -> str:
        """Put this string after user input but before first LLM call."""
        return "Are follow up questions needed here:"
|
|
|
|
2022-12-19 02:51:23 +00:00
|
|
|
class SelfAskWithSearchChain(AgentExecutor):
    """Chain that does self ask with search.

    Example:
        .. code-block:: python

            from langchain import SelfAskWithSearchChain, OpenAI, SerpAPIWrapper
            search_chain = SerpAPIWrapper()
            self_ask = SelfAskWithSearchChain(llm=OpenAI(), search_chain=search_chain)
    """

    def __init__(self, llm: BaseLLM, search_chain: SerpAPIWrapper, **kwargs: Any):
        """Initialize with just an LLM and a search chain."""
        # Wrap the search chain as the single tool the agent expects.
        tools = [Tool(name="Intermediate Answer", func=search_chain.run)]
        super().__init__(
            agent=SelfAskWithSearchAgent.from_llm_and_tools(llm, tools),
            tools=tools,
            **kwargs,
        )
|