add llm for loop

harrison/use_output_parser
Harrison Chase 2 years ago
parent a57e74996f
commit 3ef44f41b7

@ -0,0 +1,290 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"id": "4c475754",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"from langchain.prompts.base import BaseOutputParser\n",
"from langchain import OpenAI, LLMChain\n",
"from langchain.chains.llm_for_loop import LLMForLoopChain"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "efcdd239",
"metadata": {},
"outputs": [],
"source": [
"# First we make a chain that generates the list"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2b1884f5",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"import re\n",
"class ListOutputParser(BaseOutputParser):\n",
" \n",
" def __init__(self, regex: Optional[str] = None):\n",
" self.regex=regex\n",
" \n",
" def parse(self, text: str) -> list:\n",
" splits = [t for t in text.split(\"\\n\") if t]\n",
" if self.regex is not None:\n",
" splits = [re.match(self.regex, s).group(1) for s in splits]\n",
" return splits"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b2b7f8fa",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"You are a list maker. Your job is make lists given a certain user input.\n",
"\n",
"The format of your lists should be:\n",
"\n",
"```\n",
"List:\n",
"- Item 1\n",
"- Item 2\n",
"...\n",
"```\n",
"\n",
"Begin!:\n",
"\n",
"User input: {input}\n",
"List:\"\"\"\n",
"output_parser = ListOutputParser(regex=\"- (.*)\")\n",
"prompt = PromptTemplate(template=template, input_variables=[\"input\"], output_parser=output_parser)\n",
"\n",
"chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "2f8ea6ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Tesla', 'Nissan', 'BMW', 'BYD', 'Volkswagen']"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain.predict_and_parse(input=\"top 5 ev companies\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "1fdfc7cb",
"metadata": {},
"outputs": [],
"source": [
"# Next we generate the chain that we run over each item"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "0b8f115a",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"For the following company, explain the origin of their name:\n",
"\n",
"Company: {company}\n",
"Explanation of their name:\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"company\"])\n",
"\n",
"explanation_chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "6d636881",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Tesla\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'\\n\\nTesla is a company that specializes in electric cars and renewable energy. The company is named after Nikola Tesla, a Serbian-American inventor and electrical engineer who was born in the 19th century.'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"explanation_chain.predict(company=\"Tesla\")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "3c236dd3",
"metadata": {},
"outputs": [],
"source": [
"# Now we combine them"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "941b6389",
"metadata": {},
"outputs": [],
"source": [
"for_loop_chain = LLMForLoopChain(llm_chain=chain, apply_chain=explanation_chain)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "98c39dbc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Tesla\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Nissan\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: BMW\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: BYD\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
"\n",
"Company: Volkswagen\n",
"Explanation of their name:\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"['\\n\\nTesla was named after the Serbian-American inventor Nikola Tesla, who was known for his work in electricity and magnetism.',\n",
" '\\n\\nNissan is a Japanese company, and their name comes from the Japanese word for \"sun.\"',\n",
" \"\\n\\nThe company's name is an abbreviation for Bayerische Motoren Werke, which is German for Bavarian Motor Works.\",\n",
" '\\n\\nThe company\\'s name is derived from the Chinese characters \"Baiyu Dong\", which literally mean \"to catch the rain in the east\". The name is a reference to the company\\'s origins in the city of Shenzhen, in southeastern China.',\n",
" '\\n\\nVolkswagen is a German car company. The word \"Volkswagen\" means \"people\\'s car\" in German.']"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"for_loop_chain.run_list(input=\"top 5 ev companies\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a2c1803",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,5 +1,5 @@
"""Chain that just formats a prompt and calls an LLM."""
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
from pydantic import BaseModel, Extra
@ -78,3 +78,11 @@ class LLMChain(Chain, BaseModel):
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
    """Run predict, then pass the completion through the prompt's output parser.

    If the prompt has no output parser configured, the raw completion
    string is returned unchanged.
    """
    completion = self.predict(**kwargs)
    parser = self.prompt.output_parser
    return completion if parser is None else parser.parse(completion)

@ -0,0 +1,51 @@
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
class LLMForLoopChain(Chain, BaseModel):
    """Chain that first uses an LLM to generate multiple items then loops over them.

    The ``llm_chain`` is expected to produce (via its prompt's output parser)
    a list of items; ``apply_chain`` is then run once per item.
    """

    llm_chain: LLMChain
    # Chain used to generate the list of items to iterate over.
    apply_chain: Chain
    # Chain applied to each item that is generated.
    output_key: str = "text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return self.llm_chain.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def run_list(self, **kwargs: Any) -> List[str]:
        """Get a list of items from the LLM chain, then run ``apply_chain`` on each.

        Returns one output string per generated item, in generation order.
        """
        # predict_and_parse relies on the llm_chain prompt's output parser
        # to turn the raw completion into an iterable of items.
        output_items = self.llm_chain.predict_and_parse(**kwargs)
        return [self.apply_chain.run(item) for item in output_items]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Join the per-item outputs into one text blob so this class satisfies
        # the standard single-output Chain interface.
        return {self.output_key: "\n\n".join(self.run_list(**inputs))}

@ -1,8 +1,8 @@
"""BasePrompt schema definition."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Union, Optional
from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, root_validator, Extra
from langchain.formatting import formatter
@ -29,7 +29,7 @@ def check_valid_template(
raise ValueError("Invalid prompt schema.")
class OutputParser(ABC):
class BaseOutputParser(ABC):
"""Class to parse the output of an LLM call."""
@abstractmethod
@ -37,22 +37,21 @@ class OutputParser(ABC):
"""Parse the output of an LLM call."""
class DefaultParser(OutputParser):
"""Just return the text."""
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
"""Parse the output of an LLM call."""
return text
class BasePromptTemplate(BaseModel, ABC):
"""Base prompt should expose the format method, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
output_parser: OutputParser = Field(default_factory=DefaultParser)
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not restricted names."""

Loading…
Cancel
Save