forked from Archives/langchain
Compare commits
9 Commits
main
...
harrison/u
Author | SHA1 | Date |
---|---|---|
Harrison Chase | cc606180cd | 2 years ago |
Harrison Chase | f423bbc8ac | 2 years ago |
Harrison Chase | bfe50949f5 | 2 years ago |
Harrison Chase | 9966fd0e05 | 2 years ago |
Harrison Chase | 3ef44f41b7 | 2 years ago |
Harrison Chase | a57e74996f | 2 years ago |
Harrison Chase | 67685b874e | 2 years ago |
Harrison Chase | 9b674d3dc6 | 2 years ago |
Harrison Chase | c09fe1dfdf | 2 years ago |
@ -0,0 +1,290 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "4c475754",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.prompts.base import BaseOutputParser\n",
|
||||
"from langchain import OpenAI, LLMChain\n",
|
||||
"from langchain.chains.llm_for_loop import LLMForLoopChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "efcdd239",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# First we make a chain that generates the list"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "2b1884f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Optional\n",
|
||||
"import re\n",
|
||||
"class ListOutputParser(BaseOutputParser):\n",
|
||||
" \n",
|
||||
" def __init__(self, regex: Optional[str] = None):\n",
|
||||
" self.regex=regex\n",
|
||||
" \n",
|
||||
" def parse(self, text: str) -> list:\n",
|
||||
" splits = [t for t in text.split(\"\\n\") if t]\n",
|
||||
" if self.regex is not None:\n",
|
||||
" splits = [re.match(self.regex, s).group(1) for s in splits]\n",
|
||||
" return splits"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "b2b7f8fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"You are a list maker. Your job is to make lists given a certain user input.\n",
|
||||
"\n",
|
||||
"The format of your lists should be:\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"List:\n",
|
||||
"- Item 1\n",
|
||||
"- Item 2\n",
|
||||
"...\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Begin!:\n",
|
||||
"\n",
|
||||
"User input: {input}\n",
|
||||
"List:\"\"\"\n",
|
||||
"output_parser = ListOutputParser(regex=\"- (.*)\")\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"input\"], output_parser=output_parser)\n",
|
||||
"\n",
|
||||
"chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"id": "2f8ea6ba",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['Tesla', 'Nissan', 'BMW', 'BYD', 'Volkswagen']"
|
||||
]
|
||||
},
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chain.predict_and_parse(input=\"top 5 ev companies\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"id": "1fdfc7cb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Next we generate the chain that we run over each item"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"id": "0b8f115a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"For the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: {company}\n",
|
||||
"Explanation of their name:\"\"\"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"company\"])\n",
|
||||
"\n",
|
||||
"explanation_chain = LLMChain(llm=OpenAI(), prompt=prompt, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "6d636881",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Tesla\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n\\nTesla is a company that specializes in electric cars and renewable energy. The company is named after Nikola Tesla, a Serbian-American inventor and electrical engineer who was born in the 19th century.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"explanation_chain.predict(company=\"Tesla\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"id": "3c236dd3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Now we combine them"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"id": "941b6389",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for_loop_chain = LLMForLoopChain(llm_chain=chain, apply_chain=explanation_chain)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"id": "98c39dbc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Tesla\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Nissan\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: BMW\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: BYD\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
||||
"Prompt after formatting:\n",
|
||||
"\u001b[32;1m\u001b[1;3mFor the following company, explain the origin of their name:\n",
|
||||
"\n",
|
||||
"Company: Volkswagen\n",
|
||||
"Explanation of their name:\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['\\n\\nTesla was named after the Serbian-American inventor Nikola Tesla, who was known for his work in electricity and magnetism.',\n",
|
||||
" '\\n\\nNissan is a Japanese company, and their name comes from the Japanese word for \"sun.\"',\n",
|
||||
" \"\\n\\nThe company's name is an abbreviation for Bayerische Motoren Werke, which is German for Bavarian Motor Works.\",\n",
|
||||
" '\\n\\nThe company\\'s name is derived from the Chinese characters \"Baiyu Dong\", which literally mean \"to catch the rain in the east\". The name is a reference to the company\\'s origins in the city of Shenzhen, in southeastern China.',\n",
|
||||
" '\\n\\nVolkswagen is a German car company. The word \"Volkswagen\" means \"people\\'s car\" in German.']"
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for_loop_chain.run_list(input=\"top 5 ev companies\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5a2c1803",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
"""Chain that first uses an LLM to generate multiple items then loops over them."""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.base import ListOutputParser
|
||||
|
||||
|
||||
class LLMForLoopChain(Chain, BaseModel):
    """Chain that first uses an LLM to generate multiple items then loops over them.

    The ``llm_chain`` prompt must carry a ``ListOutputParser`` so that its
    output can be split into individual items; ``apply_chain`` is then run
    once per item and the per-item results are joined for the final output.
    """

    llm_chain: LLMChain
    """LLM chain to use to generate multiple items."""
    apply_chain: Chain
    """Chain to apply to each item that is generated."""
    output_key: str = "text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return self.llm_chain.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    @root_validator()
    def validate_output_parser(cls, values: Dict) -> Dict:
        """Validate that the correct inputs exist for all chains."""
        # The list-generating prompt must parse its completion into a list,
        # otherwise run_list has nothing to iterate over.
        parser = values["llm_chain"].prompt.output_parser
        if not isinstance(parser, ListOutputParser):
            raise ValueError(
                "The OutputParser on the base prompt should be of type "
                f"ListOutputParser, got {type(parser)}"
            )
        return values

    def run_list(self, **kwargs: Any) -> List[str]:
        """Get list from LLM chain and then run chain on each item."""
        items = self.llm_chain.predict_and_parse(**kwargs)
        return [self.apply_chain.run(item) for item in items]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # Collapse the per-item outputs into a single text blob for callers
        # that use the standard Chain interface.
        joined = "\n\n".join(self.run_list(**inputs))
        return {self.output_key: joined}
|
Loading…
Reference in New Issue