Added SmartGPT workflow (issue #4463) (#4816)

# Added SmartGPT workflow by providing SmartLLM wrapper around LLMs
Edit:
As @hwchase17 suggested, this should be a chain, not an LLM. I have
adapted the PR.

It is used like this:
```
from langchain.prompts import PromptTemplate
from langchain.chains import SmartLLMChain
from langchain.chat_models import ChatOpenAI

hard_question = "I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?"
hard_question_prompt = PromptTemplate.from_template(hard_question)

llm = ChatOpenAI(model_name="gpt-4")
prompt = PromptTemplate.from_template(hard_question)
chain = SmartLLMChain(llm=llm, prompt=prompt, verbose=True)

chain.run({})
```


Original text: 
Added SmartLLM wrapper around LLMs to allow for SmartGPT workflow (as in
https://youtu.be/wVzuvf9D9BU). SmartLLM can be used wherever LLM can be
used. E.g:

```
smart_llm = SmartLLM(llm=OpenAI())
smart_llm("What would be a good company name for a company that makes colorful socks?")
```
or
```
smart_llm = SmartLLM(llm=OpenAI())
prompt = PromptTemplate(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=smart_llm, prompt=prompt)
chain.run("colorful socks")
```

SmartGPT consists of 3 steps:

1. Ideate - generate n possible solutions ("ideas") to user prompt
2. Critique - find flaws in every idea & select best one
3. Resolve - improve upon best idea & return it

Fixes #4463

## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:

- @hwchase17
- @agola11

Twitter: [@UmerHAdil](https://twitter.com/@UmerHAdil) | Discord:
RicChilligerDude#7589

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/9020/head^2
UmerHA 1 year ago committed by GitHub
parent 1d3735a84c
commit 8aab39e3ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,281 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "9e9b7651",
"metadata": {},
"source": [
"# How to use a SmartLLMChain\n",
"\n",
"A SmartLLMChain is a form of self-critique chain that can help you if have particularly complex questions to answer. Instead of doing a single LLM pass, it instead performs these 3 steps:\n",
"1. Ideation: Pass the user prompt n times through the LLM to get n output proposals (called \"ideas\"), where n is a parameter you can set \n",
"2. Critique: The LLM critiques all ideas to find possible flaws and picks the best one \n",
"3. Resolve: The LLM tries to improve upon the best idea (as chosen in the critique step) and outputs it. This is then the final output.\n",
"\n",
"SmartLLMChains are based on the SmartGPT workflow proposed in https://youtu.be/wVzuvf9D9BU.\n",
"\n",
"Note that SmartLLMChains\n",
"- use more LLM passes (ie n+2 instead of just 1)\n",
"- only work then the underlying LLM has the capability for reflection, whicher smaller models often don't\n",
"- only work with underlying models that return exactly 1 output, not multiple\n",
"\n",
"This notebook demonstrates how to use a SmartLLMChain."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "714dede0",
"metadata": {},
"source": [
"##### Same LLM for all steps"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "d3f7fb22",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"...\""
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "10e5ece6",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain_experimental.smart_llm import SmartLLMChain"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1780da51",
"metadata": {},
"source": [
"As example question, we will use \"I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\". This is an example from the original SmartGPT video (https://youtu.be/wVzuvf9D9BU?t=384). While this seems like a very easy question, LLMs struggle do these kinds of questions that involve numbers and physical reasoning.\n",
"\n",
"As we will see, all 3 initial ideas are completely wrong - even though we're using GPT4! Only when using self-reflection do we get a correct answer. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "054af6b1",
"metadata": {},
"outputs": [],
"source": [
"hard_question = \"I have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8049cecd",
"metadata": {},
"source": [
"So, we first create an LLM and prompt template"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "811ea8e1",
"metadata": {},
"outputs": [],
"source": [
"prompt = PromptTemplate.from_template(hard_question)\n",
"llm = ChatOpenAI(temperature=0, model_name=\"gpt-4\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "50b602e4",
"metadata": {},
"source": [
"Now we can create a SmartLLMChain"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8cd49199",
"metadata": {},
"outputs": [],
"source": [
"chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=3, verbose=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6a72f276",
"metadata": {},
"source": [
"Now we can use the SmartLLM as a drop-in replacement for our LLM. E.g.:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "074e5e75",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SmartLLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mI have a 12 liter jug and a 6 liter jug. I want to measure 6 liters. How do I do it?\u001b[0m\n",
"Idea 1:\n",
"\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n",
"2. Pour the water from the 6-liter jug into the 12-liter jug.\n",
"3. Fill the 6-liter jug again.\n",
"4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n",
"5. The amount of water left in the 6-liter jug will be exactly 6 liters.\u001b[0m\n",
"Idea 2:\n",
"\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n",
"2. Pour the water from the 6-liter jug into the 12-liter jug.\n",
"3. Fill the 6-liter jug again.\n",
"4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n",
"5. Since the 12-liter jug is now full, there will be 2 liters of water left in the 6-liter jug.\n",
"6. Empty the 12-liter jug.\n",
"7. Pour the 2 liters of water from the 6-liter jug into the 12-liter jug.\n",
"8. Fill the 6-liter jug completely again.\n",
"9. Pour the water from the 6-liter jug into the 12-liter jug, which already has 2 liters in it.\n",
"10. Now, the 12-liter jug will have exactly 6 liters of water (2 liters from before + 4 liters from the 6-liter jug).\u001b[0m\n",
"Idea 3:\n",
"\u001b[36;1m\u001b[1;3m1. Fill the 6-liter jug completely.\n",
"2. Pour the water from the 6-liter jug into the 12-liter jug.\n",
"3. Fill the 6-liter jug again.\n",
"4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full.\n",
"5. The amount of water left in the 6-liter jug will be exactly 6 liters.\u001b[0m\n",
"Critique:\n",
"\u001b[33;1m\u001b[1;3mIdea 1:\n",
"1. Fill the 6-liter jug completely. (No flaw)\n",
"2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n",
"3. Fill the 6-liter jug again. (No flaw)\n",
"4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n",
"5. The amount of water left in the 6-liter jug will be exactly 6 liters. (Flaw: This statement is incorrect, as there will be no water left in the 6-liter jug after pouring it into the 12-liter jug.)\n",
"\n",
"Idea 2:\n",
"1. Fill the 6-liter jug completely. (No flaw)\n",
"2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n",
"3. Fill the 6-liter jug again. (No flaw)\n",
"4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n",
"5. Since the 12-liter jug is now full, there will be 2 liters of water left in the 6-liter jug. (Flaw: This statement is incorrect, as the 12-liter jug will not be full and there will be no water left in the 6-liter jug.)\n",
"6. Empty the 12-liter jug. (No flaw)\n",
"7. Pour the 2 liters of water from the 6-liter jug into the 12-liter jug. (Flaw: This step is based on the incorrect assumption that there are 2 liters of water left in the 6-liter jug.)\n",
"8. Fill the 6-liter jug completely again. (No flaw)\n",
"9. Pour the water from the 6-liter jug into the 12-liter jug, which already has 2 liters in it. (Flaw: This step is based on the incorrect assumption that there are 2 liters of water in the 12-liter jug.)\n",
"10. Now, the 12-liter jug will have exactly 6 liters of water (2 liters from before + 4 liters from the 6-liter jug). (Flaw: This conclusion is based on the incorrect assumptions made in the previous steps.)\n",
"\n",
"Idea 3:\n",
"1. Fill the 6-liter jug completely. (No flaw)\n",
"2. Pour the water from the 6-liter jug into the 12-liter jug. (No flaw)\n",
"3. Fill the 6-liter jug again. (No flaw)\n",
"4. Carefully pour the water from the 6-liter jug into the 12-liter jug until the 12-liter jug is full. (Flaw: The 12-liter jug will never be full in this step, as it can hold 12 liters and we are only pouring 6 liters into it.)\n",
"5. The amount of water left in the 6-liter jug will be exactly 6 liters. (Flaw: This statement is incorrect, as there will be no water left in the 6-liter jug after pouring it into the 12-liter jug.)\u001b[0m\n",
"Resolution:\n",
"\u001b[32;1m\u001b[1;3m1. Fill the 12-liter jug completely.\n",
"2. Pour the water from the 12-liter jug into the 6-liter jug until the 6-liter jug is full.\n",
"3. The amount of water left in the 12-liter jug will be exactly 6 liters.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'1. Fill the 12-liter jug completely.\\n2. Pour the water from the 12-liter jug into the 6-liter jug until the 6-liter jug is full.\\n3. The amount of water left in the 12-liter jug will be exactly 6 liters.'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.run({})"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bbfebea1",
"metadata": {},
"source": [
"##### Different LLM for different steps"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5be6ec08",
"metadata": {},
"source": [
"You can also use different LLMs for the different steps by passing `ideation_llm`, `critique_llm` and `resolve_llm`. You might want to do this to use a more creative (i.e., high-temperature) model for ideation and a more strict (i.e., low-temperature) model for critique and resolution."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9c33fa19",
"metadata": {},
"outputs": [],
"source": [
"chain = SmartLLMChain(\n",
" ideation_llm=ChatOpenAI(temperature=0.9, model_name=\"gpt-4\"),\n",
" llm=ChatOpenAI(\n",
" temperature=0, model_name=\"gpt-4\"\n",
" ), # will be used for critqiue and resolution as no specific llms are given\n",
" prompt=prompt,\n",
" n_ideas=3,\n",
" verbose=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "886c1cc1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,5 @@
"""Generalized implementation of SmartGPT (origin: https://youtu.be/wVzuvf9D9BU)"""
from langchain_experimental.smart_llm.base import SmartLLMChain
__all__ = ["SmartLLMChain"]

@ -0,0 +1,323 @@
"""Chain for applying self-critique using the SmartGPT workflow."""
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, root_validator
class SmartLLMChain(Chain):
"""
Generalized implementation of SmartGPT (origin: https://youtu.be/wVzuvf9D9BU)
A SmartLLMChain is an LLMChain that instead of simply passing the prompt to the LLM
performs these 3 steps:
1. Ideate: Pass the user prompt to an ideation LLM n_ideas times,
each result is an "idea"
2. Critique: Pass the ideas to a critique LLM which looks for flaws in the ideas
& picks the best one
3. Resolve: Pass the critique to a resolver LLM which improves upon the best idea
& outputs only the (improved version of) the best output
In total, SmartLLMChain pass will use n_ideas+2 LLM calls
Note that SmartLLMChain will only improve results (compared to a basic LLMChain),
when the underlying models have the capability for reflection, which smaller models
often don't.
Finally, a SmartLLMChain assumes that each underlying LLM outputs exactly 1 result.
"""
class SmartLLMChainHistory:
question: str = ""
ideas: List[str] = []
critique: str = ""
@property
def n_ideas(self) -> int:
return len(self.ideas)
def ideation_prompt_inputs(self) -> Dict[str, Any]:
return {"question": self.question}
def critique_prompt_inputs(self) -> Dict[str, Any]:
return {
"question": self.question,
**{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
}
def resolve_prompt_inputs(self) -> Dict[str, Any]:
return {
"question": self.question,
**{f"idea_{i+1}": idea for i, idea in enumerate(self.ideas)},
"critique": self.critique,
}
prompt: BasePromptTemplate
"""Prompt object to use."""
ideation_llm: Optional[BaseLanguageModel] = None
"""LLM to use in ideation step. If None given, 'llm' will be used."""
critique_llm: Optional[BaseLanguageModel] = None
"""LLM to use in critique step. If None given, 'llm' will be used."""
resolver_llm: Optional[BaseLanguageModel] = None
"""LLM to use in resolve step. If None given, 'llm' will be used."""
llm: Optional[BaseLanguageModel] = None
"""LLM to use for each steps, if no specific llm for that step is given. """
n_ideas: int = 3
"""Number of ideas to generate in idea step"""
return_intermediate_steps: bool = False
"""Whether to return ideas and critique, in addition to resolution."""
history: SmartLLMChainHistory = SmartLLMChainHistory()
class Config:
extra = Extra.forbid
@root_validator
@classmethod
def validate_inputs(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Ensure we have an LLM for each step."""
llm = values.get("llm")
ideation_llm = values.get("ideation_llm")
critique_llm = values.get("critique_llm")
resolver_llm = values.get("resolver_llm")
if not llm and not ideation_llm:
raise ValueError(
"Either ideation_llm or llm needs to be given. Pass llm, "
"if you want to use the same llm for all steps, or pass "
"ideation_llm, critique_llm and resolver_llm if you want "
"to use different llms for each step."
)
if not llm and not critique_llm:
raise ValueError(
"Either critique_llm or llm needs to be given. Pass llm, "
"if you want to use the same llm for all steps, or pass "
"ideation_llm, critique_llm and resolver_llm if you want "
"to use different llms for each step."
)
if not llm and not resolver_llm:
raise ValueError(
"Either resolve_llm or llm needs to be given. Pass llm, "
"if you want to use the same llm for all steps, or pass "
"ideation_llm, critique_llm and resolver_llm if you want "
"to use different llms for each step."
)
if llm and ideation_llm and critique_llm and resolver_llm:
raise ValueError(
"LLMs are given for each step (ideation_llm, critique_llm,"
" resolver_llm), but backup LLM (llm) is also given, which"
" would not be used."
)
return values
@property
def input_keys(self) -> List[str]:
"""Defines the input keys."""
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
"""Defines the output keys."""
if self.return_intermediate_steps:
return ["ideas", "critique", "resolution"]
return ["resolution"]
def prep_prompts(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[PromptValue, Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
if "stop" in inputs:
stop = inputs["stop"]
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
return prompt, stop
def _call(
self,
input_list: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
prompt, stop = self.prep_prompts(input_list, run_manager=run_manager)
self.history.question = prompt.to_string()
ideas = self._ideate(stop, run_manager)
self.history.ideas = ideas
critique = self._critique(stop, run_manager)
self.history.critique = critique
resolution = self._resolve(stop, run_manager)
if self.return_intermediate_steps:
return {"ideas": ideas, "critique": critique, "resolution": resolution}
return {"resolution": resolution}
def _get_text_from_llm_result(self, result: LLMResult, step: str) -> str:
"""Between steps, only the LLM result text is passed, not the LLMResult object.
This function extracts the text from an LLMResult."""
if len(result.generations) != 1:
raise ValueError(
f"In SmartLLM the LLM result in step {step} is not "
"exactly 1 element. This should never happen"
)
if len(result.generations[0]) != 1:
raise ValueError(
f"In SmartLLM the LLM in step {step} returned more than "
"1 output. SmartLLM only works with LLMs returning "
"exactly 1 output."
)
return result.generations[0][0].text
def get_prompt_strings(
self, stage: str
) -> List[Tuple[Type[BaseMessagePromptTemplate], str]]:
role_strings: List[Tuple[Type[BaseMessagePromptTemplate], str]] = []
role_strings.append(
(
HumanMessagePromptTemplate,
"Question: {question}\nAnswer: Let's work this out in a step by "
"step way to be sure we have the right answer:",
)
)
if stage == "ideation":
return role_strings
role_strings.extend(
[
*[
(
AIMessagePromptTemplate,
"Idea " + str(i + 1) + ": {idea_" + str(i + 1) + "}",
)
for i in range(self.n_ideas)
],
(
HumanMessagePromptTemplate,
"You are a researcher tasked with investigating the "
f"{self.n_ideas} response options provided. List the flaws and "
"faulty logic of each answer options. Let'w work this out in a step"
" by step way to be sure we have all the errors:",
),
]
)
if stage == "critique":
return role_strings
role_strings.extend(
[
(AIMessagePromptTemplate, "Critique: {critique}"),
(
HumanMessagePromptTemplate,
"You are a resolved tasked with 1) finding which of "
f"the {self.n_ideas} anwer options the researcher thought was "
"best,2) improving that answer and 3) printing the answer in full. "
"Don't output anything for step 1 or 2, only the full answer in 3. "
"Let's work this out in a step by step way to be sure we have "
"the right answer:",
),
]
)
if stage == "resolve":
return role_strings
raise ValueError(
"stage should be either 'ideation', 'critique' or 'resolve',"
f" but it is '{stage}'. This should never happen."
)
def ideation_prompt(self) -> ChatPromptTemplate:
return ChatPromptTemplate.from_strings(self.get_prompt_strings("ideation"))
def critique_prompt(self) -> ChatPromptTemplate:
return ChatPromptTemplate.from_strings(self.get_prompt_strings("critique"))
def resolve_prompt(self) -> ChatPromptTemplate:
return ChatPromptTemplate.from_strings(self.get_prompt_strings("resolve"))
def _ideate(
self,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> List[str]:
"""Generate n_ideas ideas as response to user prompt."""
llm = self.ideation_llm if self.ideation_llm else self.llm
prompt = self.ideation_prompt().format_prompt(
**self.history.ideation_prompt_inputs()
)
callbacks = run_manager.get_child() if run_manager else None
if llm:
ideas = [
self._get_text_from_llm_result(
llm.generate_prompt([prompt], stop, callbacks),
step="ideate",
)
for _ in range(self.n_ideas)
]
for i, idea in enumerate(ideas):
_colored_text = get_colored_text(idea, "blue")
_text = f"Idea {i+1}:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
return ideas
else:
raise ValueError("llm is none, which should never happen")
def _critique(
self,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> str:
"""Critique each of the ideas from ideation stage & select best one."""
llm = self.critique_llm if self.critique_llm else self.llm
prompt = self.critique_prompt().format_prompt(
**self.history.critique_prompt_inputs()
)
callbacks = run_manager.handlers if run_manager else None
if llm:
critique = self._get_text_from_llm_result(
llm.generate_prompt([prompt], stop, callbacks), step="critique"
)
_colored_text = get_colored_text(critique, "yellow")
_text = "Critique:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
return critique
else:
raise ValueError("llm is none, which should never happen")
def _resolve(
self,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> str:
"""Improve upon the best idea as chosen in critique step & return it."""
llm = self.resolver_llm if self.resolver_llm else self.llm
prompt = self.resolve_prompt().format_prompt(
**self.history.resolve_prompt_inputs()
)
callbacks = run_manager.handlers if run_manager else None
if llm:
resolution = self._get_text_from_llm_result(
llm.generate_prompt([prompt], stop, callbacks), step="resolve"
)
_colored_text = get_colored_text(resolution, "green")
_text = "Resolution:\n" + _colored_text
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
return resolution
else:
raise ValueError("llm is none, which should never happen")

@ -0,0 +1,120 @@
"""Test SmartLLM."""
from langchain.chat_models import FakeListChatModel
from langchain.llms import FakeListLLM
from langchain.prompts.prompt import PromptTemplate
from langchain_experimental.smart_llm import SmartLLMChain
def test_ideation() -> None:
# test that correct responses are returned
responses = ["Idea 1", "Idea 2", "Idea 3"]
llm = FakeListLLM(responses=responses)
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = SmartLLMChain(llm=llm, prompt=prompt)
prompt_value, _ = chain.prep_prompts({"product": "socks"})
chain.history.question = prompt_value.to_string()
results = chain._ideate()
assert results == responses
# test that correct number of responses are returned
for i in range(1, 5):
responses = [f"Idea {j+1}" for j in range(i)]
llm = FakeListLLM(responses=responses)
chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=i)
prompt_value, _ = chain.prep_prompts({"product": "socks"})
chain.history.question = prompt_value.to_string()
results = chain._ideate()
assert len(results) == i
def test_critique() -> None:
response = "Test Critique"
llm = FakeListLLM(responses=[response])
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=2)
prompt_value, _ = chain.prep_prompts({"product": "socks"})
chain.history.question = prompt_value.to_string()
chain.history.ideas = ["Test Idea 1", "Test Idea 2"]
result = chain._critique()
assert result == response
def test_resolver() -> None:
response = "Test resolution"
llm = FakeListLLM(responses=[response])
prompt = PromptTemplate(
input_variables=["product"],
template="What is a good name for a company that makes {product}?",
)
chain = SmartLLMChain(llm=llm, prompt=prompt, n_ideas=2)
prompt_value, _ = chain.prep_prompts({"product": "socks"})
chain.history.question = prompt_value.to_string()
chain.history.ideas = ["Test Idea 1", "Test Idea 2"]
chain.history.critique = "Test Critique"
result = chain._resolve()
assert result == response
def test_all_steps() -> None:
joke = "Why did the chicken cross the Mobius strip?"
response = "Resolution response"
ideation_llm = FakeListLLM(responses=["Ideation response" for _ in range(20)])
critique_llm = FakeListLLM(responses=["Critique response" for _ in range(20)])
resolver_llm = FakeListLLM(responses=[response for _ in range(20)])
prompt = PromptTemplate(
input_variables=["joke"],
template="Explain this joke to me: {joke}?",
)
chain = SmartLLMChain(
ideation_llm=ideation_llm,
critique_llm=critique_llm,
resolver_llm=resolver_llm,
prompt=prompt,
)
result = chain(joke)
assert result["joke"] == joke
assert result["resolution"] == response
def test_intermediate_output() -> None:
joke = "Why did the chicken cross the Mobius strip?"
llm = FakeListLLM(responses=[f"Response {i+1}" for i in range(5)])
prompt = PromptTemplate(
input_variables=["joke"],
template="Explain this joke to me: {joke}?",
)
chain = SmartLLMChain(llm=llm, prompt=prompt, return_intermediate_steps=True)
result = chain(joke)
assert result["joke"] == joke
assert result["ideas"] == [f"Response {i+1}" for i in range(3)]
assert result["critique"] == "Response 4"
assert result["resolution"] == "Response 5"
def test_all_steps_with_chat_model() -> None:
joke = "Why did the chicken cross the Mobius strip?"
response = "Resolution response"
ideation_llm = FakeListChatModel(responses=["Ideation response" for _ in range(20)])
critique_llm = FakeListChatModel(responses=["Critique response" for _ in range(20)])
resolver_llm = FakeListChatModel(responses=[response for _ in range(20)])
prompt = PromptTemplate(
input_variables=["joke"],
template="Explain this joke to me: {joke}?",
)
chain = SmartLLMChain(
ideation_llm=ideation_llm,
critique_llm=critique_llm,
resolver_llm=resolver_llm,
prompt=prompt,
)
result = chain(joke)
assert result["joke"] == joke
assert result["resolution"] == response
Loading…
Cancel
Save