"""Chain for applying self-critique using the SmartGPT workflow.""" from typing import Any, Dict, List, Optional, Tuple, Type from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.input import get_colored_text from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ( AIMessagePromptTemplate, BaseMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, ) from langchain.schema import LLMResult, PromptValue from langchain_experimental.pydantic_v1 import Extra, root_validator class SmartLLMChain(Chain): """Chain for applying self-critique using the SmartGPT workflow. See details at https://youtu.be/wVzuvf9D9BU A SmartLLMChain is an LLMChain that instead of simply passing the prompt to the LLM performs these 3 steps: 1. Ideate: Pass the user prompt to an ideation LLM n_ideas times, each result is an "idea" 2. Critique: Pass the ideas to a critique LLM which looks for flaws in the ideas & picks the best one 3. Resolve: Pass the critique to a resolver LLM which improves upon the best idea & outputs only the (improved version of) the best output In total, SmartLLMChain pass will use n_ideas+2 LLM calls Note that SmartLLMChain will only improve results (compared to a basic LLMChain), when the underlying models have the capability for reflection, which smaller models often don't. Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly 1 result. """ class SmartLLMChainHistory: question: str = "" ideas: List[str] = [] critique: str = "" @property def n_ideas(self) -> int: return len(self.ideas) def ideation_prompt_inputs(self) -> Dict[str, Any]: return {"question": self.question} def critique_prompt_inputs(self) -> Dict[str, Any]: return { "question": self.question, **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)}, } def resolve_prompt_inputs(self) -> Dict[str, Any]: return { "question": self.question, **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)}, "critique": self.critique, } prompt: BasePromptTemplate """Prompt object to use.""" output_key: str = "resolution" ideation_llm: Optional[BaseLanguageModel] = None """LLM to use in ideation step. If None given, 'llm' will be used.""" critique_llm: Optional[BaseLanguageModel] = None """LLM to use in critique step. If None given, 'llm' will be used.""" resolver_llm: Optional[BaseLanguageModel] = None """LLM to use in resolve step. If None given, 'llm' will be used.""" llm: Optional[BaseLanguageModel] = None """LLM to use for each steps, if no specific llm for that step is given. """ n_ideas: int = 3 """Number of ideas to generate in idea step""" return_intermediate_steps: bool = False """Whether to return ideas and critique, in addition to resolution.""" history: SmartLLMChainHistory = SmartLLMChainHistory() class Config: extra = Extra.forbid # TODO: move away from `root_validator` since it is deprecated in pydantic v2 # and causes mypy type-checking failures (hence the `type: ignore`) @root_validator # type: ignore[call-overload] @classmethod def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Ensure we have an LLM for each step.""" llm = values.get("llm") ideation_llm = values.get("ideation_llm") critique_llm = values.get("critique_llm") resolver_llm = values.get("resolver_llm") if not llm and not ideation_llm: raise ValueError( "Either ideation_llm or llm needs to be given. Pass llm, " "if you want to use the same llm for all steps, or pass " "ideation_llm, critique_llm and resolver_llm if you want " "to use different llms for each step." ) if not llm and not critique_llm: raise ValueError( "Either critique_llm or llm needs to be given. Pass llm, " "if you want to use the same llm for all steps, or pass " "ideation_llm, critique_llm and resolver_llm if you want " "to use different llms for each step." ) if not llm and not resolver_llm: raise ValueError( "Either resolve_llm or llm needs to be given. Pass llm, " "if you want to use the same llm for all steps, or pass " "ideation_llm, critique_llm and resolver_llm if you want " "to use different llms for each step." ) if llm and ideation_llm and critique_llm and resolver_llm: raise ValueError( "LLMs are given for each step (ideation_llm, critique_llm," " resolver_llm), but backup LLM (llm) is also given, which" " would not be used." ) return values @property def input_keys(self) -> List[str]: """Defines the input keys.""" return self.prompt.input_variables @property def output_keys(self) -> List[str]: """Defines the output keys.""" if self.return_intermediate_steps: return ["ideas", "critique", self.output_key] return [self.output_key] def prep_prompts( self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Tuple[PromptValue, Optional[List[str]]]: """Prepare prompts from inputs.""" stop = None if "stop" in inputs: stop = inputs["stop"] selected_inputs = {k: inputs[k] for k in self.prompt.input_variables} prompt = self.prompt.format_prompt(**selected_inputs) _colored_text = get_colored_text(prompt.to_string(), "green") _text = "Prompt after formatting:\n" + _colored_text if run_manager: run_manager.on_text(_text, end="\n", verbose=self.verbose) if "stop" in inputs and inputs["stop"] != stop: raise ValueError( "If `stop` is present in any inputs, should be present in all." ) return prompt, stop def _call( self, input_list: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, Any]: prompt, stop = self.prep_prompts(input_list, run_manager=run_manager) self.history.question = prompt.to_string() ideas = self._ideate(stop, run_manager) self.history.ideas = ideas critique = self._critique(stop, run_manager) self.history.critique = critique resolution = self._resolve(stop, run_manager) if self.return_intermediate_steps: return {"ideas": ideas, "critique": critique, self.output_key: resolution} return {self.output_key: resolution} def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str: """Between steps, only the LLM result text is passed, not the LLMResult object. This function extracts the text from an LLMResult.""" if len(result.generations) != 1: raise ValueError( f"In SmartLLM the LLM result in step {step} is not " "exactly 1 element. This should never happen" ) if len(result.generations[0]) != 1: raise ValueError( f"In SmartLLM the LLM in step {step} returned more than " "1 output. SmartLLM only works with LLMs returning " "exactly 1 output." ) return result.generations[0][0].text def get_prompt_strings( self, stage: str ) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]: role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = [] role_strings.append( ( HumanMessagePromptTemplate, "Question: {question}\nAnswer: Let's work this out in a step by " "step way to be sure we have the right answer:", ) ) if stage == "ideation": return role_strings role_strings.extend( [ *[ ( AIMessagePromptTemplate, "Idea " + str(i + 1) + ": {idea_" + str(i + 1) + "}", ) for i in range(self.n_ideas) ], ( HumanMessagePromptTemplate, "You are a researcher tasked with investigating the " f"{self.n_ideas} response options provided. List the flaws and " "faulty logic of each answer option. Let's work this out in a step" " by step way to be sure we have all the errors:", ), ] ) if stage == "critique": return role_strings role_strings.extend( [ (AIMessagePromptTemplate, "Critique: {critique}"), ( HumanMessagePromptTemplate, "You are a resolver tasked with 1) finding which of " f"the {self.n_ideas} answer options the researcher thought was " "best, 2) improving that answer and 3) printing the answer in " "full. Don't output anything for step 1 or 2, only the full " "answer in 3. Let's work this out in a step by step way to " "be sure we have the right answer:", ), ] ) if stage == "resolve": return role_strings raise ValueError( "stage should be either 'ideation', 'critique' or 'resolve'," f" but it is '{stage}'. This should never happen." ) def ideation_prompt(self) -> ChatPromptTemplate: return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation")) def critique_prompt(self) -> ChatPromptTemplate: return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique")) def resolve_prompt(self) -> ChatPromptTemplate: return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve")) def _ideate( self, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForChainRun] = None, ) -> List[str]: """Generate n_ideas ideas as response to user prompt.""" llm = self.ideation_llm if self.ideation_llm else self.llm prompt = self.ideation_prompt().format_prompt( **self.history.ideation_prompt_inputs() ) callbacks = run_manager.get_child() if run_manager else None if llm: ideas = [ self._get_text_from_llm_result( llm.generate_prompt([prompt], stop, callbacks), step="ideate", ) for _ in range(self.n_ideas) ] for i, idea in enumerate(ideas): _colored_text = get_colored_text(idea, "blue") _text = f"Idea {i+1}:\n" + _colored_text if run_manager: run_manager.on_text(_text, end="\n", verbose=self.verbose) return ideas else: raise ValueError("llm is none, which should never happen") def _critique( self, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForChainRun] = None, ) -> str: """Critique each of the ideas from ideation stage & select best one.""" llm = self.critique_llm if self.critique_llm else self.llm prompt = self.critique_prompt().format_prompt( **self.history.critique_prompt_inputs() ) callbacks = run_manager.handlers if run_manager else None if llm: critique = self._get_text_from_llm_result( llm.generate_prompt([prompt], stop, callbacks), step="critique" ) _colored_text = get_colored_text(critique, "yellow") _text = "Critique:\n" + _colored_text if run_manager: run_manager.on_text(_text, end="\n", verbose=self.verbose) return critique else: raise ValueError("llm is none, which should never happen") def _resolve( self, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForChainRun] = None, ) -> str: """Improve upon the best idea as chosen in critique step & return it.""" llm = self.resolver_llm if self.resolver_llm else self.llm prompt = self.resolve_prompt().format_prompt( **self.history.resolve_prompt_inputs() ) callbacks = run_manager.handlers if run_manager else None if llm: resolution = self._get_text_from_llm_result( llm.generate_prompt([prompt], stop, callbacks), step="resolve" ) _colored_text = get_colored_text(resolution, "green") _text = "Resolution:\n" + _colored_text if run_manager: run_manager.on_text(_text, end="\n", verbose=self.verbose) return resolution else: raise ValueError("llm is none, which should never happen")