You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

143 lines
4.4 KiB

"""Chain that does self ask with search."""
from typing import Any, Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.self_ask_with_search.prompt import PROMPT
from langchain.chains.serpapi import SerpAPIChain
from langchain.llms.base import LLM
def extract_answer(generated: str) -> str:
    """Extract answer from text.

    Looks at the last line of ``generated``. If that line contains a colon,
    only the text after its last colon is kept; otherwise the whole line is
    used. A single leading space and a single trailing period are removed.

    Args:
        generated: Text to pull the answer from.

    Returns:
        The extracted answer. This is an empty string when the last line is
        empty or nothing follows its final colon.
    """
    last_line = generated.split("\n")[-1]
    # split(":")[-1] is the whole line when there is no colon, and the text
    # after the last colon otherwise.
    answer = last_line.split(":")[-1]
    # startswith/endswith are safe on an empty string, unlike indexing.
    if answer.startswith(" "):
        answer = answer[1:]
    if answer.endswith("."):
        answer = answer[:-1]
    return answer
def extract_question(generated: str, followup: str) -> str:
    """Extract question from text.

    Looks at the last line of ``generated``. If that line contains a colon,
    only the text after its last colon is kept; otherwise the whole line is
    used. A single leading space is removed.

    A diagnostic message is printed (and processing continues) when the last
    line does not contain ``followup`` or when the extracted text does not end
    with a question mark.

    Args:
        generated: Text to pull the question from.
        followup: Marker expected to appear in the last line.

    Returns:
        The extracted question, possibly an empty string.
    """
    last_line = generated.split("\n")[-1]
    if followup not in last_line:
        print("we probably should never get here..." + generated)

    # split(":")[-1] is the whole line when there is no colon, and the text
    # after the last colon otherwise.
    question = last_line.split(":")[-1]
    if question.startswith(" "):
        question = question[1:]
    # endswith is safe on an empty string, unlike indexing with [-1].
    if not question.endswith("?"):
        print("we probably should never get here..." + generated)

    return question
def get_last_line(generated: str) -> str:
    """Get the last line in text.

    Args:
        generated: Text that may contain newline characters.

    Returns:
        Everything after the final newline, or the whole string when it
        contains no newline.
    """
    # rsplit with maxsplit=1 yields the whole string as the only element when
    # no newline is present, so no separate branch is needed.
    return generated.rsplit("\n", 1)[-1]
def greenify(_input: str) -> str:
    """Wrap text in ANSI escapes so terminals show it with a green highlight.

    Args:
        _input: Text to highlight.

    Returns:
        ``_input`` preceded by the SGR 102 escape and followed by a reset.
    """
    highlight_on = "\x1b[102m"
    highlight_off = "\x1b[0m"
    return "".join((highlight_on, _input, highlight_off))
def yellowfy(_input: str) -> str:
    """Wrap text in ANSI escapes so terminals show it highlighted.

    NOTE(review): the escape used is SGR 106, which is the bright-cyan
    background in the standard ANSI palette (bright yellow is 103) — confirm
    which colour is intended before changing the name or the code.

    Args:
        _input: Text to highlight.

    Returns:
        ``_input`` preceded by the SGR 106 escape and followed by a reset.
    """
    highlight_on = "\x1b[106m"
    highlight_off = "\x1b[0m"
    return "".join((highlight_on, _input, highlight_off))
class SelfAskWithSearchChain(Chain, BaseModel):
    """Chain that does self ask with search.

    The LLM is asked whether follow-up questions are needed. While its last
    line contains a follow-up question, that question is sent to the search
    chain and the result is appended to the prompt as an intermediate answer.
    Once no follow-up is produced, a final answer is requested if the LLM has
    not already given one.
    """

    # LLM used to generate follow-up questions and answers.
    llm: LLM
    # Chain queried with each follow-up question.
    search_chain: SerpAPIChain
    # Key under which the question is read from the inputs dict.
    input_key: str = "question"
    # Key under which the result is stored in the outputs dict.
    output_key: str = "answer"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    # NOTE(review): the Chain base class may expect input_keys/output_keys to
    # be properties; no decorator is visible here, so none is added — confirm.
    def input_keys(self) -> List[str]:
        """Expect input key."""
        return [self.input_key]

    def output_keys(self) -> List[str]:
        """Expect output key."""
        return [self.output_key]

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run the self-ask loop for the question found in ``inputs``.

        Progress is echoed to stdout as the prompt grows.

        Args:
            inputs: Mapping that holds the question under ``self.input_key``.

        Returns:
            A dict mapping ``self.output_key`` to the accumulated prompt
            followed by the last LLM completion.
        """
        question = inputs[self.input_key]
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        intermediate = "\nIntermediate answer:"
        followup = "Follow up:"
        finalans = "\nSo the final answer is:"

        cur_prompt = f"{question}\nAre follow up questions needed here:"
        print(cur_prompt, end="")
        ret_text = llm_chain.predict(input=cur_prompt, stop=[intermediate])
        print(greenify(ret_text), end="")

        while followup in get_last_line(ret_text):
            cur_prompt += ret_text
            question = extract_question(ret_text, followup)
            # NOTE(review): the right-hand side of this assignment was missing;
            # this assumes SerpAPIChain exposes search(question) returning a
            # str or None — confirm against the SerpAPIChain definition.
            external_answer = self.search_chain.search(question)
            if external_answer is not None:
                cur_prompt += intermediate + " " + external_answer + "."
                print(
                    intermediate + " " + yellowfy(external_answer) + ".",
                    end="",
                )
            else:
                # We only get here in the very rare case that Google returns
                # no answer: let the LLM supply the intermediate answer itself.
                cur_prompt += intermediate
                print(intermediate + " ")
                cur_prompt += llm_chain.predict(
                    input=cur_prompt, stop=["\n" + followup, finalans]
                )
            # Request the next step after either branch. Without refreshing
            # ret_text in the fallback case the loop condition never changes
            # and the same follow-up would be processed again and again.
            ret_text = llm_chain.predict(input=cur_prompt, stop=[intermediate])
            print(greenify(ret_text), end="")

        if finalans not in ret_text:
            # The LLM stopped asking follow-ups without stating a final
            # answer, so prompt for one explicitly.
            cur_prompt += finalans
            print(finalans, end="")
            ret_text = llm_chain.predict(input=cur_prompt, stop=["\n"])
            print(greenify(ret_text), end="")

        return {self.output_key: cur_prompt + ret_text}

    def run(self, question: str) -> str:
        """More user-friendly interface for interfacing with self ask with search."""
        return self({self.input_key: question})[self.output_key]