# Added SmartGPT workflow by providing SmartLLM wrapper around LLMs

Edit: As @hwchase17 suggested, this should be a chain, not an LLM. I have adapted the PR. It is used like this:

```
from langchain.prompts import PromptTemplate
from langchain.chains import SmartLLMChain
from langchain.chat_models import ChatOpenAI

hard_question = "I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?"
llm = ChatOpenAI(model_name="gpt-4")
prompt = PromptTemplate.from_template(hard_question)
chain = SmartLLMChain(llm=llm, prompt=prompt, verbose=True)
chain.run({})
```

Original text: Added SmartLLM wrapper around LLMs to allow for SmartGPT workflow (as in https://youtu.be/wVzuvf9D9BU). SmartLLM can be used wherever an LLM can be used. E.g.:

```
smart_llm = SmartLLM(llm=OpenAI())
smart_llm("What would be a good company name for a company that makes colorful socks?")
```

or

```
smart_llm = SmartLLM(llm=OpenAI())
prompt = PromptTemplate(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=smart_llm, prompt=prompt)
chain.run("colorful socks")
```

SmartGPT consists of 3 steps:

1. Ideate - generate n possible solutions ("ideas") to the user prompt
2. Critique - find flaws in every idea & select the best one
3. Resolve - improve upon the best idea & return it

Fixes #4463

## Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested:

- @hwchase17
- @agola11

Twitter: [@UmerHAdil](https://twitter.com/@UmerHAdil) | Discord: RicChilligerDude#7589

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
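For reviewers who want to see what each step produced, the chain can also expose its intermediate outputs. Below is a minimal sketch (not one of the PR's own examples) that assumes the `langchain_experimental.smart_llm` import path used in the notebook added by this PR; it relies on the `return_intermediate_steps` flag and the `ideas` / `critique` / `resolution` output keys defined in `base.py`:

```
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_experimental.smart_llm import SmartLLMChain

prompt = PromptTemplate.from_template(
    "I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?"
)
chain = SmartLLMChain(
    llm=ChatOpenAI(temperature=0, model_name="gpt-4"),
    prompt=prompt,
    n_ideas=3,
    return_intermediate_steps=True,  # also return "ideas" and "critique", not just "resolution"
)
# With multiple output keys, call the chain directly instead of using chain.run()
result = chain({})
print(result["ideas"])       # the n_ideas proposals from the ideation step
print(result["critique"])    # the critique that picks apart those proposals
print(result["resolution"])  # the final, improved answer
```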
@ -0,0 +1,281 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9e9b7651",
   "metadata": {},
   "source": [
    "# How to use a SmartLLMChain\n",
    "\n",
    "A SmartLLMChain is a form of self-critique chain that can help you if you have particularly complex questions to answer. Instead of doing a single LLM pass, it performs these 3 steps:\n",
    "1. Ideation: Pass the user prompt n times through the LLM to get n output proposals (called \"ideas\"), where n is a parameter you can set \n",
    "2. Critique: The LLM critiques all ideas to find possible flaws and picks the best one \n",
    "3. Resolve: The LLM tries to improve upon the best idea (as chosen in the critique step) and outputs it. This is then the final output.\n",
    "\n",
    "SmartLLMChains are based on the SmartGPT workflow proposed in https://youtu.be/wVzuvf9D9BU.\n",
    "\n",
    "Note that SmartLLMChains\n",
    "- use more LLM passes (i.e. n+2 instead of just 1)\n",
    "- only work when the underlying LLM has the capability for reflection, which smaller models often don't\n",
    "- only work with underlying models that return exactly 1 output, not multiple\n",
    "\n",
    "This notebook demonstrates how to use a SmartLLMChain."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "714dede0",
   "metadata": {},
   "source": [
    "##### Same LLM for all steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d3f7fb22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"...\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "10e5ece6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain_experimental.smart_llm import SmartLLMChain"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1780da51",
   "metadata": {},
   "source": [
    "As example question, we will use \"I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\". This is an example from the original SmartGPT video (https://youtu.be/wVzuvf9D9BU?t=384). While this seems like a very easy question, LLMs struggle with these kinds of questions that involve numbers and physical reasoning.\n",
    "\n",
    "As we will see, all 3 initial ideas are completely wrong - even though we're using GPT-4! Only when using self-reflection do we get a correct answer. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "054af6b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "hard_question = \"I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8049cecd",
   "metadata": {},
   "source": [
    "So, we first create an LLM and prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "811ea8e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = PromptTemplate.from_template(hard_question)\n",
    "llm = ChatOpenAI(temperature=0, model_name=\"gpt-4\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "50b602e4",
   "metadata": {},
   "source": [
    "Now we can create a SmartLLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8cd49199",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=3, verbose=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6a72f276",
   "metadata": {},
   "source": [
    "Now we can use the SmartLLMChain as a drop-in replacement for an LLMChain. E.g.:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "074e5e75",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new SmartLLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\u001b[0m\n",
      "Idea 1:\n",
      "\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n",
      "2. Pour the water from the 6-liter jug into the 12-liter jug.\n",
      "3. Fill the 6-liter jug again.\n",
      "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n",
      "5. The amount of water left in the 6-liter jug will be exactly 6 liters.\u001b[0m\n",
      "Idea 2:\n",
      "\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n",
      "2. Pour the water from the 6-liter jug into the 12-liter jug.\n",
      "3. Fill the 6-liter jug again.\n",
      "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n",
      "5. Since the 12-liter jug is now full, there will be 2 liters of water left in the 6-liter jug.\n",
      "6. Empty the 12-liter jug.\n",
      "7. Pour the 2 liters of water from the 6-liter jug into the 12-liter jug.\n",
      "8. Fill the 6-liter jug completely again.\n",
      "9. Pour the water from the 6-liter jug into the 12-liter jug, which already has 2 liters in it.\n",
      "10. Now, the 12-liter jug will have exactly 6 liters of water (2 liters from before + 4 liters from the 6-liter jug).\u001b[0m\n",
      "Idea 3:\n",
      "\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n",
      "2. Pour the water from the 6-liter jug into the 12-liter jug.\n",
      "3. Fill the 6-liter jug again.\n",
      "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n",
      "5. The amount of water left in the 6-liter jug will be exactly 6 liters.\u001b[0m\n",
      "Critique:\n",
      "\u001b[33;1m\u001b[1;3mIdea 1:\n",
      "1. Fill the 6-liter jug completely. (No flaw)\n",
      "2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n",
      "3. Fill the 6-liter jug again. (No flaw)\n",
      "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n",
      "5. The amount of water left in the 6-liter jug will be exactly 6 liters. (Flaw: This statement is incorrect, as there will be no water left in the 6-liter jug after pouring it into the 12-liter jug.)\n",
      "\n",
      "Idea 2:\n",
      "1. Fill the 6-liter jug completely. (No flaw)\n",
      "2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n",
      "3. Fill the 6-liter jug again. (No flaw)\n",
      "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n",
      "5. Since the 12-liter jug is now full, there will be 2 liters of water left in the 6-liter jug. (Flaw: This statement is incorrect, as the 12-liter jug will not be full and there will be no water left in the 6-liter jug.)\n",
      "6. Empty the 12-liter jug. (No flaw)\n",
      "7. Pour the 2 liters of water from the 6-liter jug into the 12-liter jug. (Flaw: This step is based on the incorrect assumption that there are 2 liters of water left in the 6-liter jug.)\n",
      "8. Fill the 6-liter jug completely again. (No flaw)\n",
      "9. Pour the water from the 6-liter jug into the 12-liter jug, which already has 2 liters in it. (Flaw: This step is based on the incorrect assumption that there are 2 liters of water in the 12-liter jug.)\n",
      "10. Now, the 12-liter jug will have exactly 6 liters of water (2 liters from before + 4 liters from the 6-liter jug). (Flaw: This conclusion is based on the incorrect assumptions made in the previous steps.)\n",
      "\n",
      "Idea 3:\n",
      "1. Fill the 6-liter jug completely. (No flaw)\n",
      "2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n",
      "3. Fill the 6-liter jug again. (No flaw)\n",
      "4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n",
      "5. The amount of water left in the 6-liter jug will be exactly 6 liters. (Flaw: This statement is incorrect, as there will be no water left in the 6-liter jug after pouring it into the 12-liter jug.)\u001b[0m\n",
      "Resolution:\n",
      "\u001b[32;1m\u001b[1;3m1. Fill the 12-liter jug completely.\n",
      "2. Pour the water from the 12-liter jug into the 6-liter jug until the 6-liter jug is full.\n",
      "3. The amount of water left in the 12-liter jug will be exactly 6 liters.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'1. Fill the 12-liter jug completely.\\n2. Pour the water from the 12-liter jug into the 6-liter jug until the 6-liter jug is full.\\n3. The amount of water left in the 12-liter jug will be exactly 6 liters.'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.run({})"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bbfebea1",
   "metadata": {},
   "source": [
    "##### Different LLM for different steps"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5be6ec08",
   "metadata": {},
   "source": [
    "You can also use different LLMs for the different steps by passing `ideation_llm`, `critique_llm` and `resolver_llm`. You might want to do this to use a more creative (i.e., high-temperature) model for ideation and a more strict (i.e., low-temperature) model for critique and resolution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9c33fa19",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = SmartLLMChain(\n",
    "    ideation_llm=ChatOpenAI(temperature=0.9, model_name=\"gpt-4\"),\n",
    "    llm=ChatOpenAI(\n",
    "        temperature=0, model_name=\"gpt-4\"\n",
    "    ),  # will be used for critique and resolution, as no specific llms are given\n",
    "    prompt=prompt,\n",
    "    n_ideas=3,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "886c1cc1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@ -0,0 +1,5 @@
"""Generalized implementation of SmartGPT (origin: https://youtu.be/wVzuvf9D9BU)"""

from langchain_experimental.smart_llm.base import SmartLLMChain

__all__ = ["SmartLLMChain"]
@ -0,0 +1,323 @@
"""Chain for applying self-critique using the SmartGPT workflow."""
from typing import Any, Dict, List, Optional, Tuple, Type

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, root_validator


class SmartLLMChain(Chain):
    """
    Generalized implementation of SmartGPT (origin: https://youtu.be/wVzuvf9D9BU)

    A SmartLLMChain is an LLMChain that, instead of simply passing the prompt to
    the LLM once, performs these 3 steps:
    1. Ideate: Pass the user prompt to an ideation LLM n_ideas times,
       each result is an "idea"
    2. Critique: Pass the ideas to a critique LLM which looks for flaws in the ideas
       & picks the best one
    3. Resolve: Pass the critique to a resolver LLM which improves upon the best idea
       & outputs only the (improved version of) the best output

    In total, a SmartLLMChain pass will use n_ideas+2 LLM calls.

    Note that a SmartLLMChain will only improve results (compared to a basic LLMChain)
    when the underlying models have the capability for reflection, which smaller models
    often don't.

    Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly 1 result.
    """

    class SmartLLMChainHistory:
        question: str = ""
        ideas: List[str] = []
        critique: str = ""

        @property
        def n_ideas(self) -> int:
            return len(self.ideas)

        def ideation_prompt_inputs(self) -> Dict[str, Any]:
            return {"question": self.question}

        def critique_prompt_inputs(self) -> Dict[str, Any]:
            return {
                "question": self.question,
                **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
            }

        def resolve_prompt_inputs(self) -> Dict[str, Any]:
            return {
                "question": self.question,
                **{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
                "critique": self.critique,
            }

    prompt: BasePromptTemplate
    """Prompt object to use."""
    ideation_llm: Optional[BaseLanguageModel] = None
    """LLM to use in ideation step. If None given, 'llm' will be used."""
    critique_llm: Optional[BaseLanguageModel] = None
    """LLM to use in critique step. If None given, 'llm' will be used."""
    resolver_llm: Optional[BaseLanguageModel] = None
    """LLM to use in resolve step. If None given, 'llm' will be used."""
    llm: Optional[BaseLanguageModel] = None
    """LLM to use for each step, if no specific llm for that step is given."""
    n_ideas: int = 3
    """Number of ideas to generate in idea step"""
    return_intermediate_steps: bool = False
    """Whether to return ideas and critique, in addition to resolution."""
    history: SmartLLMChainHistory = SmartLLMChainHistory()

    class Config:
        extra = Extra.forbid

    @root_validator
    @classmethod
    def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure we have an LLM for each step."""
        llm = values.get("llm")
        ideation_llm = values.get("ideation_llm")
        critique_llm = values.get("critique_llm")
        resolver_llm = values.get("resolver_llm")

        if not llm and not ideation_llm:
            raise ValueError(
                "Either ideation_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if not llm and not critique_llm:
            raise ValueError(
                "Either critique_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if not llm and not resolver_llm:
            raise ValueError(
                "Either resolver_llm or llm needs to be given. Pass llm, "
                "if you want to use the same llm for all steps, or pass "
                "ideation_llm, critique_llm and resolver_llm if you want "
                "to use different llms for each step."
            )
        if llm and ideation_llm and critique_llm and resolver_llm:
            raise ValueError(
                "LLMs are given for each step (ideation_llm, critique_llm,"
                " resolver_llm), but backup LLM (llm) is also given, which"
                " would not be used."
            )
        return values
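
    # Added commentary (not part of the original file): the validator above accepts
    # either a single backup `llm` for all steps, or one LLM per step, but not both
    # at once. Illustrative calls (model instances and `prompt` are placeholders):
    #
    #     SmartLLMChain(llm=ChatOpenAI(), prompt=prompt)   # ok: same llm for all steps
    #     SmartLLMChain(                                   # ok: one llm per step
    #         ideation_llm=ChatOpenAI(temperature=0.9),
    #         critique_llm=ChatOpenAI(temperature=0),
    #         resolver_llm=ChatOpenAI(temperature=0),
    #         prompt=prompt,
    #     )
    #     SmartLLMChain(prompt=prompt)                     # ValueError: no LLM given
    #     SmartLLMChain(llm=..., ideation_llm=..., critique_llm=...,
    #                   resolver_llm=..., prompt=prompt)   # ValueError: backup llm unused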

    @property
    def input_keys(self) -> List[str]:
        """Defines the input keys."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Defines the output keys."""
        if self.return_intermediate_steps:
            return ["ideas", "critique", "resolution"]
        return ["resolution"]

    def prep_prompts(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[PromptValue, Optional[List[str]]]:
        """Prepare prompts from inputs."""
        stop = None
        if "stop" in inputs:
            stop = inputs["stop"]
        selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**selected_inputs)
        _colored_text = get_colored_text(prompt.to_string(), "green")
        _text = "Prompt after formatting:\n" + _colored_text
        if run_manager:
            run_manager.on_text(_text, end="\n", verbose=self.verbose)
        if "stop" in inputs and inputs["stop"] != stop:
            raise ValueError(
                "If `stop` is present in any inputs, should be present in all."
            )
        return prompt, stop

    def _call(
        self,
        input_list: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        prompt, stop = self.prep_prompts(input_list, run_manager=run_manager)
        self.history.question = prompt.to_string()
        ideas = self._ideate(stop, run_manager)
        self.history.ideas = ideas
        critique = self._critique(stop, run_manager)
        self.history.critique = critique
        resolution = self._resolve(stop, run_manager)
        if self.return_intermediate_steps:
            return {"ideas": ideas, "critique": critique, "resolution": resolution}
        return {"resolution": resolution}

    def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
        """Between steps, only the LLM result text is passed, not the LLMResult object.
        This function extracts the text from an LLMResult."""
        if len(result.generations) != 1:
            raise ValueError(
                f"In SmartLLM the LLM result in step {step} is not "
                "exactly 1 element. This should never happen"
            )
        if len(result.generations[0]) != 1:
            raise ValueError(
                f"In SmartLLM the LLM in step {step} returned more than "
                "1 output. SmartLLM only works with LLMs returning "
                "exactly 1 output."
            )
        return result.generations[0][0].text

    def get_prompt_strings(
        self, stage: str
    ) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]:
        role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = []
        role_strings.append(
            (
                HumanMessagePromptTemplate,
                "Question: {question}\nAnswer: Let's work this out in a step by "
                "step way to be sure we have the right answer:",
            )
        )
        if stage == "ideation":
            return role_strings
        role_strings.extend(
            [
                *[
                    (
                        AIMessagePromptTemplate,
                        "Idea " + str(i + 1) + ": {idea_" + str(i + 1) + "}",
                    )
                    for i in range(self.n_ideas)
                ],
                (
                    HumanMessagePromptTemplate,
                    "You are a researcher tasked with investigating the "
                    f"{self.n_ideas} response options provided. List the flaws and "
                    "faulty logic of each answer option. Let's work this out in a step"
                    " by step way to be sure we have all the errors:",
                ),
            ]
        )
        if stage == "critique":
            return role_strings
        role_strings.extend(
            [
                (AIMessagePromptTemplate, "Critique: {critique}"),
                (
                    HumanMessagePromptTemplate,
                    "You are a resolver tasked with 1) finding which of "
                    f"the {self.n_ideas} answer options the researcher thought was "
                    "best, 2) improving that answer and 3) printing the answer in full. "
                    "Don't output anything for step 1 or 2, only the full answer in 3. "
                    "Let's work this out in a step by step way to be sure we have "
                    "the right answer:",
                ),
            ]
        )
        if stage == "resolve":
            return role_strings
        raise ValueError(
            "stage should be either 'ideation', 'critique' or 'resolve',"
            f" but it is '{stage}'. This should never happen."
        )

    def ideation_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation"))

    def critique_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique"))

    def resolve_prompt(self) -> ChatPromptTemplate:
        return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve"))

    def _ideate(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[str]:
        """Generate n_ideas ideas as response to user prompt."""
        llm = self.ideation_llm if self.ideation_llm else self.llm
        prompt = self.ideation_prompt().format_prompt(
            **self.history.ideation_prompt_inputs()
        )
        callbacks = run_manager.get_child() if run_manager else None
        if llm:
            ideas = [
                self._get_text_from_llm_result(
                    llm.generate_prompt([prompt], stop, callbacks),
                    step="ideate",
                )
                for _ in range(self.n_ideas)
            ]
            for i, idea in enumerate(ideas):
                _colored_text = get_colored_text(idea, "blue")
                _text = f"Idea {i+1}:\n" + _colored_text
                if run_manager:
                    run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return ideas
        else:
            raise ValueError("llm is none, which should never happen")

    def _critique(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> str:
        """Critique each of the ideas from ideation stage & select best one."""
        llm = self.critique_llm if self.critique_llm else self.llm
        prompt = self.critique_prompt().format_prompt(
            **self.history.critique_prompt_inputs()
        )
        callbacks = run_manager.handlers if run_manager else None
        if llm:
            critique = self._get_text_from_llm_result(
                llm.generate_prompt([prompt], stop, callbacks), step="critique"
            )
            _colored_text = get_colored_text(critique, "yellow")
            _text = "Critique:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return critique
        else:
            raise ValueError("llm is none, which should never happen")

    def _resolve(
        self,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> str:
        """Improve upon the best idea as chosen in critique step & return it."""
        llm = self.resolver_llm if self.resolver_llm else self.llm
        prompt = self.resolve_prompt().format_prompt(
            **self.history.resolve_prompt_inputs()
        )
        callbacks = run_manager.handlers if run_manager else None
        if llm:
            resolution = self._get_text_from_llm_result(
                llm.generate_prompt([prompt], stop, callbacks), step="resolve"
            )
            _colored_text = get_colored_text(resolution, "green")
            _text = "Resolution:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            return resolution
        else:
            raise ValueError("llm is none, which should never happen")
@ -0,0 +1,120 @@
"""Test SmartLLM."""
from langchain.chat_models import FakeListChatModel
from langchain.llms import FakeListLLM
from langchain.prompts.prompt import PromptTemplate

from langchain_experimental.smart_llm import SmartLLMChain


def test_ideation() -> None:
    # test that correct responses are returned
    responses = ["Idea 1", "Idea 2", "Idea 3"]
    llm = FakeListLLM(responses=responses)
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = SmartLLMChain(llm=llm, prompt=prompt)
    prompt_value, _ = chain.prep_prompts({"product": "socks"})
    chain.history.question = prompt_value.to_string()
    results = chain._ideate()
    assert results == responses

    # test that the correct number of responses is returned
    for i in range(1, 5):
        responses = [f"Idea {j+1}" for j in range(i)]
        llm = FakeListLLM(responses=responses)
        chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=i)
        prompt_value, _ = chain.prep_prompts({"product": "socks"})
        chain.history.question = prompt_value.to_string()
        results = chain._ideate()
        assert len(results) == i


def test_critique() -> None:
    response = "Test Critique"
    llm = FakeListLLM(responses=[response])
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=2)
    prompt_value, _ = chain.prep_prompts({"product": "socks"})
    chain.history.question = prompt_value.to_string()
    chain.history.ideas = ["Test Idea 1", "Test Idea 2"]
    result = chain._critique()
    assert result == response


def test_resolver() -> None:
    response = "Test resolution"
    llm = FakeListLLM(responses=[response])
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=2)
    prompt_value, _ = chain.prep_prompts({"product": "socks"})
    chain.history.question = prompt_value.to_string()
    chain.history.ideas = ["Test Idea 1", "Test Idea 2"]
    chain.history.critique = "Test Critique"
    result = chain._resolve()
    assert result == response


def test_all_steps() -> None:
    joke = "Why did the chicken cross the Mobius strip?"
    response = "Resolution response"
    ideation_llm = FakeListLLM(responses=["Ideation response" for _ in range(20)])
    critique_llm = FakeListLLM(responses=["Critique response" for _ in range(20)])
    resolver_llm = FakeListLLM(responses=[response for _ in range(20)])
    prompt = PromptTemplate(
        input_variables=["joke"],
        template="Explain this joke to me: {joke}?",
    )
    chain = SmartLLMChain(
        ideation_llm=ideation_llm,
        critique_llm=critique_llm,
        resolver_llm=resolver_llm,
        prompt=prompt,
    )
    result = chain(joke)
    assert result["joke"] == joke
    assert result["resolution"] == response


def test_intermediate_output() -> None:
    joke = "Why did the chicken cross the Mobius strip?"
    llm = FakeListLLM(responses=[f"Response {i+1}" for i in range(5)])
    prompt = PromptTemplate(
        input_variables=["joke"],
        template="Explain this joke to me: {joke}?",
    )
    chain = SmartLLMChain(llm=llm, prompt=prompt, return_intermediate_steps=True)
    result = chain(joke)
    assert result["joke"] == joke
    assert result["ideas"] == [f"Response {i+1}" for i in range(3)]
    assert result["critique"] == "Response 4"
    assert result["resolution"] == "Response 5"


def test_all_steps_with_chat_model() -> None:
    joke = "Why did the chicken cross the Mobius strip?"
    response = "Resolution response"

    ideation_llm = FakeListChatModel(responses=["Ideation response" for _ in range(20)])
    critique_llm = FakeListChatModel(responses=["Critique response" for _ in range(20)])
    resolver_llm = FakeListChatModel(responses=[response for _ in range(20)])
    prompt = PromptTemplate(
        input_variables=["joke"],
        template="Explain this joke to me: {joke}?",
    )
    chain = SmartLLMChain(
        ideation_llm=ideation_llm,
        critique_llm=critique_llm,
        resolver_llm=resolver_llm,
        prompt=prompt,
    )
    result = chain(joke)
    assert result["joke"] == joke
    assert result["resolution"] == response