forked from Archives/langchain
209 lines
5.3 KiB
Plaintext
209 lines
5.3 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "e58f4d5a",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Custom Agent for a Chat Model\n",
|
||
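|
"This notebook covers how to create a custom agent for a chat model. It uses chat-specific prompts: the agent's instructions and tool list are sent as a system message, while the user's question and the agent's previous work are sent as a human message."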
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "5268c7fa",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from langchain.agents import ZeroShotAgent, Tool, AgentExecutor\n",
|
||
|
"from langchain.chains import LLMChain\n",
|
||
|
"from langchain.utilities import SerpAPIWrapper"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "fbaa4dbe",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# One tool for the agent: SerpAPI web search, registered under the name \"Search\".\n",
|
||
|
"search = SerpAPIWrapper()\n",
|
||
|
"tools = [\n",
|
||
|
"    Tool(\n",
|
||
|
"        name=\"Search\",\n",
|
||
|
"        func=search.run,\n",
|
||
|
"        description=\"useful for when you need to answer questions about current events\",\n",
|
||
|
"    ),\n",
|
||
|
"]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "f3ba6f08",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# This ZeroShotAgent prompt is used later as the chat model's system message.\n",
|
||
|
"prefix = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\"\"\"\n",
|
||
|
"suffix = \"\"\"Begin! Remember to speak as a pirate when giving your final answer. Use lots of \"Args\"!\"\"\"\n",
|
||
|
"\n",
|
||
|
"# input_variables=[]: the question and scratchpad are filled in by the human message instead.\n",
|
||
|
"prompt = ZeroShotAgent.create_prompt(\n",
|
||
|
"    tools,\n",
|
||
|
"    prefix=prefix,\n",
|
||
|
"    suffix=suffix,\n",
|
||
|
"    input_variables=[],\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "3547a37d",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from langchain.chat_models import ChatOpenAI\n",
|
||
|
"from langchain.prompts.chat import (\n",
|
||
|
" ChatPromptTemplate,\n",
|
||
|
" SystemMessagePromptTemplate,\n",
|
||
|
" AIMessagePromptTemplate,\n",
|
||
|
" HumanMessagePromptTemplate,\n",
|
||
|
")\n",
|
||
|
"from langchain.schema import (\n",
|
||
|
" AIMessage,\n",
|
||
|
" HumanMessage,\n",
|
||
|
" SystemMessage\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "a78f886f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# System message: the ZeroShotAgent prompt built earlier with create_prompt (prefix, tools, suffix).\n",
|
||
|
"# Human message: the question plus the agent's previous work (its scratchpad).\n",
|
||
|
"messages = [\n",
|
||
|
"    SystemMessagePromptTemplate(prompt=prompt),\n",
|
||
|
"    # Plain strings, not f-strings: {input} and {agent_scratchpad} are template\n",
|
||
|
"    # placeholders for LangChain to fill in, so Python must leave the braces alone.\n",
|
||
|
"    HumanMessagePromptTemplate.from_template(\n",
|
||
|
"        \"{input}\\n\\nThis was your previous work \"\n",
|
||
|
"        \"(but I haven't seen any of it! I only see what \"\n",
|
||
|
"        \"you return as final answer):\\n{agent_scratchpad}\"\n",
|
||
|
"    ),\n",
|
||
|
"]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "dadadd70",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# A separate name keeps `prompt` (the system prompt used in `messages`) intact, so earlier cells can be re-run.\n",
|
||
|
"chat_prompt = ChatPromptTemplate.from_messages(messages)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "b7180182",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Pair the chat model with the chat prompt template built in the previous cell.\n",
|
||
|
"llm_chain = LLMChain(llm=ChatOpenAI(temperature=0), prompt=chat_prompt)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "ddddb07b",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# `allowed_tools` takes the names of the tools defined earlier.\n",
|
||
|
"tool_names = [t.name for t in tools]\n",
|
||
|
"agent = ZeroShotAgent(\n",
|
||
|
"    llm_chain=llm_chain,\n",
|
||
|
"    allowed_tools=tool_names,\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "36aef054",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# verbose=True prints the chain's intermediate steps, as shown in the next cell's output.\n",
|
||
|
"agent_executor = AgentExecutor.from_agent_and_tools(\n",
|
||
|
"    agent=agent,\n",
|
||
|
"    tools=tools,\n",
|
||
|
"    verbose=True,\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "33a4d6cc",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
|
"\u001b[32;1m\u001b[1;3mArrr, ye be in luck, matey! I'll find ye the answer to yer question.\n",
|
||
|
"\n",
|
||
|
"Thought: I need to search for the current population of Canada.\n",
|
||
|
"Action: Search\n",
|
||
|
"Action Input: \"current population of Canada 2023\"\n",
|
||
|
"\u001b[0m\n",
|
||
|
"Observation: \u001b[36;1m\u001b[1;3mThe current population of Canada is 38,623,091 as of Saturday, March 4, 2023, based on Worldometer elaboration of the latest United Nations data.\u001b[0m\n",
|
||
|
"Thought:\u001b[32;1m\u001b[1;3mAhoy, me hearties! I've found the answer to yer question.\n",
|
||
|
"\n",
|
||
|
"Final Answer: As of March 4, 2023, the population of Canada be 38,623,091. Arrr!\u001b[0m\n",
|
||
|
"\n",
|
||
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'As of March 4, 2023, the population of Canada be 38,623,091. Arrr!'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# The agent's final answer is returned and displayed as the cell's result.\n",
|
||
|
"question = \"How many people live in canada as of 2023?\"\n",
|
||
|
"agent_executor.run(question)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "6aefe978",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.1"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|