Callbacks Refactor [base] (#3256)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com>
Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Ankush Gola 2023-04-30 11:14:09 -07:00 committed by GitHub
parent 18ec22fe56
commit d3ec00b566
208 changed files with 6394 additions and 3353 deletions

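Every hunk below makes the same substitution, so it helps to see it once in isolation. A minimal before/after sketch of the pattern, reconstructed from the hunks themselves (the handler and class names are the ones these notebooks already import):

```python
from langchain.llms import OpenAI
from langchain.callbacks import StdOutCallbackHandler

# Before this commit: handlers were wrapped in a CallbackManager and
# passed through the callback_manager argument.
#   from langchain.callbacks.base import CallbackManager
#   manager = CallbackManager([StdOutCallbackHandler()])
#   llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)

# After this commit: handlers are a plain list, passed either at
# construction time or per call (both forms appear in the diffs below).
callbacks = [StdOutCallbackHandler()]
llm = OpenAI(temperature=0, callbacks=callbacks)
llm("Tell me a joke")
```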
View File

@ -61,7 +61,6 @@
"from datetime import datetime\n",
"\n",
"from langchain.llms import OpenAI\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks import AimCallbackHandler, StdOutCallbackHandler"
]
},
@ -109,8 +108,8 @@
" experiment_name=\"scenario 1: OpenAI LLM\",\n",
")\n",
"\n",
"manager = CallbackManager([StdOutCallbackHandler(), aim_callback])\n",
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
"callbacks = [StdOutCallbackHandler(), aim_callback]\n",
"llm = OpenAI(temperature=0, callbacks=callbacks)"
]
},
{
@ -177,7 +176,7 @@
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
"\n",
"test_prompts = [\n",
" {\"title\": \"documentary about good video games that push the boundary of game design\"},\n",
@ -249,13 +248,12 @@
],
"source": [
"# scenario 3 - Agent with Tools\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
"agent = initialize_agent(\n",
" tools,\n",
" llm,\n",
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
" callback_manager=manager,\n",
" verbose=True,\n",
" callbacks=callbacks,\n",
")\n",
"agent.run(\n",
" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",

View File

@ -79,7 +79,6 @@
"source": [
"from datetime import datetime\n",
"from langchain.callbacks import ClearMLCallbackHandler, StdOutCallbackHandler\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.llms import OpenAI\n",
"\n",
"# Setup and use the ClearML Callback\n",
@ -93,9 +92,9 @@
" complexity_metrics=True,\n",
" stream_logs=True\n",
")\n",
"manager = CallbackManager([StdOutCallbackHandler(), clearml_callback])\n",
"callbacks = [StdOutCallbackHandler(), clearml_callback]\n",
"# Get the OpenAI model ready to go\n",
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
"llm = OpenAI(temperature=0, callbacks=callbacks)"
]
},
{
@ -523,13 +522,12 @@
"from langchain.agents import AgentType\n",
"\n",
"# SCENARIO 2 - Agent with Tools\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
"agent = initialize_agent(\n",
" tools,\n",
" llm,\n",
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
" callback_manager=manager,\n",
" verbose=True,\n",
" callbacks=callbacks,\n",
")\n",
"agent.run(\n",
" \"Who is the wife of the person who sang summer of 69?\"\n",

View File

@ -121,7 +121,6 @@
"from datetime import datetime\n",
"\n",
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.llms import OpenAI\n",
"\n",
"comet_callback = CometCallbackHandler(\n",
@ -131,8 +130,8 @@
" tags=[\"llm\"],\n",
" visualizations=[\"dep\"],\n",
")\n",
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
"llm = OpenAI(temperature=0.9, callbacks=callbacks, verbose=True)\n",
"\n",
"llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\", \"Tell me a fact\"] * 3)\n",
"print(\"LLM result\", llm_result)\n",
@ -153,7 +152,6 @@
"outputs": [],
"source": [
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms import OpenAI\n",
"from langchain.prompts import PromptTemplate\n",
@ -164,15 +162,14 @@
" stream_logs=True,\n",
" tags=[\"synopsis-chain\"],\n",
")\n",
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
"\n",
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
"\n",
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
"\n",
"test_prompts = [{\"title\": \"Documentary about Bigfoot in Paris\"}]\n",
"print(synopsis_chain.apply(test_prompts))\n",
@ -194,7 +191,6 @@
"source": [
"from langchain.agents import initialize_agent, load_tools\n",
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.llms import OpenAI\n",
"\n",
"comet_callback = CometCallbackHandler(\n",
@ -203,15 +199,15 @@
" stream_logs=True,\n",
" tags=[\"agent\"],\n",
")\n",
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
"llm = OpenAI(temperature=0.9, callbacks=callbacks)\n",
"\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n",
"agent = initialize_agent(\n",
" tools,\n",
" llm,\n",
" agent=\"zero-shot-react-description\",\n",
" callback_manager=manager,\n",
" callbacks=callbacks,\n",
" verbose=True,\n",
")\n",
"agent.run(\n",
@ -255,7 +251,6 @@
"from rouge_score import rouge_scorer\n",
"\n",
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms import OpenAI\n",
"from langchain.prompts import PromptTemplate\n",
@ -298,10 +293,10 @@
" tags=[\"custom_metrics\"],\n",
" custom_metrics=rouge_score.compute_metric,\n",
")\n",
"manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
"llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
"callbacks = [StdOutCallbackHandler(), comet_callback]\n",
"llm = OpenAI(temperature=0.9)\n",
"\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)\n",
"\n",
"test_prompts = [\n",
" {\n",
@ -323,7 +318,7 @@
" \"\"\"\n",
" }\n",
"]\n",
"print(synopsis_chain.apply(test_prompts))\n",
"print(synopsis_chain.apply(test_prompts, callbacks=callbacks))\n",
"comet_callback.flush_tracker(synopsis_chain, finish=True)"
]
}

View File

@ -3,6 +3,7 @@
This page covers how to use the `GPT4All` wrapper within LangChain. The tutorial is divided into two parts: installation and setup, followed by usage with an example.
## Installation and Setup
- Install the Python package with `pip install pyllamacpp`
- Download a [GPT4All model](https://github.com/nomic-ai/pyllamacpp#supported-model) and place it in your desired directory
@ -28,16 +29,16 @@ To stream the model's predictions, add in a CallbackManager.
```python
from langchain.llms import GPT4All
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# There are many CallbackHandlers supported, such as
# from langchain.callbacks.streamlit import StreamlitCallbackHandler
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, verbose=True)
callbacks = [StreamingStdOutCallbackHandler()]
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)
# Generate text. Tokens are streamed through the callback manager.
model("Once upon a time, ")
model("Once upon a time, ", callbacks=callbacks)
```
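Since the hunk above interleaves the removed `CallbackManager` lines with their replacements, here is the post-change block in one piece, a sketch assuming the same local model path:

```python
from langchain.llms import GPT4All
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Handlers are now a plain list; no CallbackManager wrapper is required.
callbacks = [StreamingStdOutCallbackHandler()]
model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8)

# Tokens are streamed to stdout via the handler passed at call time.
model("Once upon a time, ", callbacks=callbacks)
```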
## Model File

View File

@ -50,7 +50,6 @@
"source": [
"from datetime import datetime\n",
"from langchain.callbacks import WandbCallbackHandler, StdOutCallbackHandler\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.llms import OpenAI"
]
},
@ -196,8 +195,8 @@
" name=\"llm\",\n",
" tags=[\"test\"],\n",
")\n",
"manager = CallbackManager([StdOutCallbackHandler(), wandb_callback])\n",
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)"
"callbacks = [StdOutCallbackHandler(), wandb_callback]\n",
"llm = OpenAI(temperature=0, callbacks=callbacks)"
]
},
{
@ -484,7 +483,7 @@
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n",
"\n",
"test_prompts = [\n",
" {\n",
@ -577,16 +576,15 @@
],
"source": [
"# SCENARIO 3 - Agent with Tools\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
"tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n",
"agent = initialize_agent(\n",
" tools,\n",
" llm,\n",
" agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n",
" callback_manager=manager,\n",
" verbose=True,\n",
")\n",
"agent.run(\n",
" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
" \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\",\n",
" callbacks=callbacks,\n",
")\n",
"wandb_callback.flush_tracker(agent, reset=False, finish=True)"
]

View File

@ -44,6 +44,8 @@ These modules are, in increasing order of complexity:
- `Agents <./modules/agents.html>`_: Agents involve an LLM making decisions about which Actions to take, taking that Action, seeing an Observation, and repeating that until done. LangChain provides a standard interface for agents, a selection of agents to choose from, and examples of end to end agents.
- `Callbacks <./modules/callbacks/getting_started.html>`_: It can be difficult to track all that occurs inside a chain or agent - callbacks help add a level of observability and introspection.
.. toctree::
:maxdepth: 1
@ -57,6 +59,7 @@ These modules are, in increasing order of complexity:
./modules/memory.md
./modules/chains.md
./modules/agents.md
./modules/callbacks/getting_started.ipynb
Use Cases
----------

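The `Callbacks` entry added to the module list above is the one conceptual addition in this commit. A minimal sketch of what that observability looks like in practice, using the same `LLMChain` pattern as the notebooks in this diff (the prompt text is illustrative):

```python
from langchain.llms import OpenAI
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

# The handler prints each chain/LLM start and end to stdout, giving
# visibility into what runs inside the chain.
callbacks = [StdOutCallbackHandler()]
prompt = PromptTemplate(input_variables=["title"], template="Write a synopsis for: {title}")
chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt, callbacks=callbacks)
chain.run("Documentary about Bigfoot in Paris")
```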
View File

@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"id": "da5df06c-af6f-4572-b9f5-0ab971c16487",
"metadata": {
"tags": []
@ -42,7 +42,6 @@
"from langchain.agents import AgentType\n",
"from langchain.llms import OpenAI\n",
"from langchain.callbacks.stdout import StdOutCallbackHandler\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.tracers import LangChainTracer\n",
"from aiohttp import ClientSession\n",
"\n",
@ -57,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 8,
"id": "fd4c294e-b1d6-44b8-b32e-2765c017e503",
"metadata": {
"tags": []
@ -73,16 +72,15 @@
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
"Action: Search\n",
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Rafael Nadal's age\n",
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 75, 63, 57, 46, 64 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Action: Search\n",
"Action Input: \"Rafael Nadal age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 36 raised to the 0.334 power\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n",
"Action: Calculator\n",
"Action Input: 36^0.334\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n",
"\n",
@ -93,18 +91,17 @@
"\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mJason Sudeikis\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Jason Sudeikis' age\n",
"Observation: \u001b[33;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
"Action: Search\n",
"Action Input: \"Jason Sudeikis age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m47 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 47 raised to the 0.23 power\n",
"Action Input: \"Harry Styles age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
"Action: Calculator\n",
"Action Input: 47^0.23\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.4242784855673896\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Jason Sudeikis, Olivia Wilde's boyfriend, is 47 years old and his age raised to the 0.23 power is 2.4242784855673896.\u001b[0m\n",
"Action Input: 29^0.23\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: Harry Styles, Olivia Wilde's boyfriend, is 29 years old and his age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
@ -113,17 +110,17 @@
"\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mMax Verstappen\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Max Verstappen's age\n",
"Observation: \u001b[33;1m\u001b[1;3mMichael Schumacher (top left) and Lewis Hamilton (top right) have each won the championship a record seven times during their careers, while Sebastian Vettel ( ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Action: Search\n",
"Action Input: \"Max Verstappen Age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m25 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.23 power\n",
"Action Input: \"Michael Schumacher age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m54 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.23 power\n",
"Action: Calculator\n",
"Action Input: 25^0.23\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.84599359907945\u001b[0m\n",
"Action Input: 54^0.23\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.502940725307012\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Max Verstappen, 25 years old, raised to the 0.23 power is 1.84599359907945.\u001b[0m\n",
"Final Answer: Michael Schumacher, aged 54, raised to the 0.23 power is 2.502940725307012.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
@ -132,18 +129,17 @@
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n",
"Action: Search\n",
"Action Input: \"US Open women's final 2019 winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu defeated Serena Williams in the final, 63, 75 to win the women's singles tennis title at the 2019 US Open. It was her first major title, and she became the first Canadian, as well as the first player born in the 2000s, to win a major singles title.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Bianca Andreescu's age.\n",
"Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out her age\n",
"Action: Search\n",
"Action Input: \"Bianca Andreescu age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m22 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the age of Bianca Andreescu and can calculate her age raised to the 0.34 power.\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.34 power\n",
"Action: Calculator\n",
"Action Input: 22^0.34\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: Bianca Andreescu won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.8603798598506933.\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Bianca Andreescu, aged 22, won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.86.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
@ -160,35 +156,32 @@
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n",
"Action: Calculator\n",
"Action Input: 53^0.19\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Jay-Z is Beyonce's husband and his age raised to the 0.19 power is 2.12624064206896.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"Serial executed in 65.11 seconds.\n"
"Serial executed in 52.47 seconds.\n"
]
}
],
"source": [
"def generate_serially():\n",
" for q in questions:\n",
"llm = OpenAI(temperature=0)\n",
"tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n",
"agent = initialize_agent(\n",
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n",
")\n",
" agent.run(q)\n",
"\n",
"s = time.perf_counter()\n",
"generate_serially()\n",
"for q in questions:\n",
" agent.run(q)\n",
"elapsed = time.perf_counter() - s\n",
"print(f\"Serial executed in {elapsed:0.2f} seconds.\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"id": "076d7b85-45ec-465d-8b31-c2ad119c3438",
"metadata": {
"tags": []
@ -202,6 +195,9 @@
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
@ -210,182 +206,94 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Beyonce's husband is and then calculate his age raised to the 0.19 power.\n",
"Action: Search\n",
"Action Input: \"Who is Beyonce's husband?\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mJay-Z\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n",
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n",
"Action: Search\n",
"Action Input: \"US Open women's final 2019 winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mJason Sudeikis\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[33;1m\u001b[1;3mMax Verstappen\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu defeated Serena Williams in the final, 63, 75 to win the women's singles tennis title at the 2019 US Open. It was her first major title, and she became the first Canadian, as well as the first player born in the 2000s, to win a major singles title.\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Jason Sudeikis' age\n",
"Action: Search\n",
"Action Input: \"Jason Sudeikis age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Jay-Z's age\n",
"Action: Search\n",
"Action Input: \"How old is Jay-Z?\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m53 years\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
"Action: Search\n",
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 75, 63, 57, 46, 64 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n",
"Action: Search\n",
"Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Beyonce's husband is and then calculate his age raised to the 0.19 power.\n",
"Action: Search\n",
"Action Input: \"Who is Beyonce's husband?\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[33;1m\u001b[1;3m47 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Max Verstappen's age\n",
"Observation: \u001b[33;1m\u001b[1;3mJay-Z\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[33;1m\u001b[1;3mMichael Schumacher (top left) and Lewis Hamilton (top right) have each won the championship a record seven times during their careers, while Sebastian Vettel ( ...\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out her age\n",
"Action: Search\n",
"Action Input: \"Max Verstappen Age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m25 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Bianca Andreescu's age.\n",
"Action Input: \"Bianca Andreescu age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Jay-Z's age\n",
"Action: Search\n",
"Action Input: \"Bianca Andreescu age\"\u001b[0m\n",
"Action Input: \"How old is Jay-Z?\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m53 years\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[33;1m\u001b[1;3m22 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n",
"Action: Calculator\n",
"Action Input: 53^0.19\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n",
"Action: Search\n",
"Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 47 raised to the 0.23 power\n",
"Action Input: \"Harry Styles age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.34 power\n",
"Action: Calculator\n",
"Action Input: 47^0.23\u001b[0m\n",
"Action Input: 22^0.34\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n",
"Action: Calculator\n",
"Action Input: 53^0.19\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n",
"Action: Calculator\n",
"Action Input: 29^0.23\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Action: Search\n",
"Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n",
"Action: Search\n",
"Action Input: \"Michael Schumacher age\"\u001b[0m\n",
"Observation: \n",
"Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.23 power\n",
"Action: Calculator\n",
"Action Input: 25^0.23\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the age of Bianca Andreescu and can calculate her age raised to the 0.34 power.\n",
"Action: Calculator\n",
"Action Input: 22^0.34\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.84599359907945\u001b[0m\n",
"Thought:\u001b[33;1m\u001b[1;3m54 years\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.4242784855673896\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n",
"Action: Calculator\n",
"Action Input: 36^0.334\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Jay-Z is Beyonce's husband and his age raised to the 0.19 power is 2.12624064206896.\u001b[0m\n",
"\n",
"Action Input: 36^0.334\u001b[0m\u001b[32;1m\u001b[1;3m I now need to calculate the age raised to the 0.23 power\n",
"Action: Calculator\n",
"Action Input: 54^0.23\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Max Verstappen, 25 years old, raised to the 0.23 power is 1.84599359907945.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Jason Sudeikis, Olivia Wilde's boyfriend, is 47 years old and his age raised to the 0.23 power is 2.4242784855673896.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.502940725307012\u001b[0m\n",
"Thought:\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: Bianca Andreescu won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.8603798598506933.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"Concurrent executed in 12.38 seconds.\n"
"Concurrent executed in 14.49 seconds.\n"
]
}
],
"source": [
"async def generate_concurrently():\n",
" agents = []\n",
" # To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n",
" # but you must manually close the client session at the end of your program/event loop\n",
" aiosession = ClientSession()\n",
" for _ in questions:\n",
" manager = CallbackManager([StdOutCallbackHandler()])\n",
" llm = OpenAI(temperature=0, callback_manager=manager)\n",
" async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n",
" agents.append(\n",
" initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
"llm = OpenAI(temperature=0)\n",
"tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n",
"agent = initialize_agent(\n",
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n",
")\n",
" tasks = [async_agent.arun(q) for async_agent, q in zip(agents, questions)]\n",
" await asyncio.gather(*tasks)\n",
" await aiosession.close()\n",
"\n",
"s = time.perf_counter()\n",
"# If running this outside of Jupyter, use asyncio.run(generate_concurrently())\n",
"await generate_concurrently()\n",
"# If running this outside of Jupyter, use asyncio.run or loop.run_until_complete\n",
"tasks = [agent.arun(q) for q in questions]\n",
"await asyncio.gather(*tasks)\n",
"elapsed = time.perf_counter() - s\n",
"print(f\"Concurrent executed in {elapsed:0.2f} seconds.\")"
]
},
{
"cell_type": "markdown",
"id": "97ef285c-4a43-4a4e-9698-cd52a1bc56c9",
"metadata": {},
"source": [
"## Using Tracing with Asynchronous Agents\n",
"\n",
"To use tracing with async agents, you must pass in a custom `CallbackManager` with `LangChainTracer` to each agent running asynchronously. This way, you avoid collisions while the trace is being collected."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "44bda05a-d33e-4e91-9a71-a0f3f96aae95",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n",
"Action: Search\n",
"Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3mRafael Nadal\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to find out Rafael Nadal's age\n",
"Action: Search\n",
"Action Input: \"Rafael Nadal age\"\u001b[0m\n",
"Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate 36 raised to the 0.334 power\n",
"Action: Calculator\n",
"Action Input: 36^0.334\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n",
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
}
],
"source": [
"# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n",
"# but you must manually close the client session at the end of your program/event loop\n",
"aiosession = ClientSession()\n",
"tracer = LangChainTracer()\n",
"tracer.load_default_session()\n",
"manager = CallbackManager([StdOutCallbackHandler(), tracer])\n",
"\n",
"# Pass the manager into the llm if you want llm calls traced.\n",
"llm = OpenAI(temperature=0, callback_manager=manager)\n",
"\n",
"async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n",
"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
"await async_agent.arun(questions[0])\n",
"await aiosession.close()"
]
}
],
"metadata": {
@ -404,7 +312,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@ -373,6 +373,7 @@
"metadata": {},
"outputs": [],
"source": [
"tools = get_tools(\"whats the weather?\")\n",
"tool_names = [tool.name for tool in tools]\n",
"agent = LLMSingleActionAgent(\n",
" llm_chain=llm_chain, \n",

View File

@ -75,6 +75,7 @@
}
],
"source": [
"\n",
"arxiv = ArxivAPIWrapper()\n",
"docs = arxiv.run(\"1605.08386\")\n",
"docs"
@ -163,7 +164,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -69,7 +69,8 @@
}
],
"source": [
"StableDiffusionTool().langchain.run(\"Please create a photo of a dog riding a skateboard\")"
"local_file_path = StableDiffusionTool().langchain.run(\"Please create a photo of a dog riding a skateboard\")\n",
"local_file_path"
]
},
{
@ -89,7 +90,7 @@
"metadata": {},
"outputs": [],
"source": [
"im = Image.open(\"/Users/harrisonchase/workplace/langchain/docs/modules/agents/tools/examples/b61c1dd9-47e2-46f1-a47c-20d27640993d/tmp4ap48vnm.jpg\")"
"im = Image.open(local_file_path)"
]
},
{

View File

@ -19,6 +19,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.agents import Tool\n",
"from langchain.utilities import PythonREPL"
]
},
@ -59,7 +60,14 @@
"id": "54fc1f03",
"metadata": {},
"outputs": [],
"source": []
"source": [
"# You can create the tool to pass to an agent\n",
"repl_tool = Tool(\n",
" name=\"python_repl\",\n",
" description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n",
" func=python_repl\n",
")"
]
}
],
"metadata": {

View File

@ -102,7 +102,15 @@
"id": "e0a1dc1c",
"metadata": {},
"outputs": [],
"source": []
"source": [
"from langchain.agents import Tool\n",
"# You can create the tool to pass to an agent\n",
"repl_tool = Tool(\n",
" name=\"python_repl\",\n",
" description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n",
" func=search.run,\n",
")"
]
}
],
"metadata": {

File diff suppressed because it is too large

View File

@ -10,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@ -37,7 +37,7 @@
"'Hello World\\n'"
]
},
"execution_count": 1,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -50,7 +50,7 @@
"\n",
"text = \"Please write a bash script that prints 'Hello World' to the console.\"\n",
"\n",
"bash_chain = LLMBashChain(llm=llm, verbose=True)\n",
"bash_chain = LLMBashChain.from_llm(llm, verbose=True)\n",
"\n",
"bash_chain.run(text)"
]
@ -65,11 +65,12 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts.prompt import PromptTemplate\n",
"from langchain.chains.llm_bash.prompt import BashOutputParser\n",
"\n",
"_PROMPT_TEMPLATE = \"\"\"If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put \"#!/bin/bash\" in your answer. Make sure to reason step by step, using this format:\n",
"Question: \"copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'\"\n",
@ -88,12 +89,12 @@
"That is the format. Begin!\n",
"Question: {question}\"\"\"\n",
"\n",
"PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE)"
"PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE, output_parser=BashOutputParser())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@ -120,13 +121,13 @@
"'Hello World\\n'"
]
},
"execution_count": 3,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bash_chain = LLMBashChain(llm=llm, prompt=PROMPT, verbose=True)\n",
"bash_chain = LLMBashChain.from_llm(llm, prompt=PROMPT, verbose=True)\n",
"\n",
"text = \"Please write a bash script that prints 'Hello World' to the console.\"\n",
"\n",
@ -134,7 +135,6 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -145,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@ -177,7 +177,7 @@
"'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'"
]
},
"execution_count": 4,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -187,7 +187,7 @@
"\n",
"\n",
"persistent_process = BashProcess(persistent=True)\n",
"bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n",
"bash_chain = LLMBashChain.from_llm(llm, bash_process=persistent_process, verbose=True)\n",
"\n",
"text = \"List the current directory then move up a level.\"\n",
"\n",
@ -196,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -224,7 +224,7 @@
"'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'"
]
},
"execution_count": 5,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -258,7 +258,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -23,28 +23,16 @@
"\n",
"\n",
"\u001b[1m> Entering new SequentialChain chain...\u001b[0m\n",
"\u001b[1mChain 0\u001b[0m:\n",
"{'statement': '\\nNone. Mammals do not lay eggs.'}\n",
"\n",
"\u001b[1mChain 1\u001b[0m:\n",
"{'assertions': '\\n• Mammals reproduce using live birth\\n• Mammals do not lay eggs\\n• Animals that lay eggs are not mammals'}\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1mChain 2\u001b[0m:\n",
"{'checked_assertions': '\\n1. True\\n\\n2. True\\n\\n3. False - Mammals are a class of animals that includes animals that lay eggs, such as monotremes (platypus and echidna).'}\n",
"\n",
"\u001b[1mChain 3\u001b[0m:\n",
"{'revised_statement': ' Monotremes, such as the platypus and echidna, lay the biggest eggs of any mammal.'}\n",
"\n",
"\n",
"\u001b[1m> Finished SequentialChain chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished LLMCheckerChain chain.\u001b[0m\n"
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' Monotremes, such as the platypus and echidna, lay the biggest eggs of any mammal.'"
"' No mammal lays the biggest eggs. The Elephant Bird, which was a species of giant bird, laid the largest eggs of any bird.'"
]
},
"execution_count": 1,
@ -60,7 +48,7 @@
"\n",
"text = \"What type of mammal lays the biggest eggs?\"\n",
"\n",
"checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n",
"checker_chain = LLMCheckerChain.from_llm(llm, verbose=True)\n",
"\n",
"checker_chain.run(text)"
]
@ -89,7 +77,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"id": "44e9ba31",
"metadata": {},
"outputs": [
@ -24,23 +24,22 @@
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"What is 13 raised to the .3432 power?\u001b[32;1m\u001b[1;3m\n",
"```python\n",
"import math\n",
"print(math.pow(13, .3432))\n",
"```text\n",
"13 ** .3432\n",
"```\n",
"...numexpr.evaluate(\"13 ** .3432\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Answer: 2.4116004626599237\\n'"
"'Answer: 2.4116004626599237'"
]
},
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@ -49,102 +48,7 @@
"from langchain import OpenAI, LLMMathChain\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"llm_math = LLMMathChain(llm=llm, verbose=True)\n",
"\n",
"llm_math.run(\"What is 13 raised to the .3432 power?\")"
]
},
{
"cell_type": "markdown",
"id": "2bdd5fc6",
"metadata": {},
"source": [
"## Customize Prompt\n",
"You can also customize the prompt that is used. Here is an example prompting it to use numpy"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "76be17b0",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"_PROMPT_TEMPLATE = \"\"\"You are GPT-3, and you can't do math.\n",
"\n",
"You can do basic math, and your memorization abilities are impressive, but you can't do any complex calculations that a human could not do in their head. You also have an annoying tendency to just make up highly specific, but wrong, answers.\n",
"\n",
"So we hooked you up to a Python 3 kernel, and now you can execute code. If you execute code, you must print out the final answer using the print function. You MUST use the python package numpy to answer your question. You must import numpy as np.\n",
"\n",
"\n",
"Question: ${{Question with hard calculation.}}\n",
"```python\n",
"${{Code that prints what you need to know}}\n",
"print(${{code}})\n",
"```\n",
"```output\n",
"${{Output of your code}}\n",
"```\n",
"Answer: ${{Answer}}\n",
"\n",
"Begin.\n",
"\n",
"Question: What is 37593 * 67?\n",
"\n",
"```python\n",
"import numpy as np\n",
"print(np.multiply(37593, 67))\n",
"```\n",
"```output\n",
"2518731\n",
"```\n",
"Answer: 2518731\n",
"\n",
"Question: {question}\"\"\"\n",
"\n",
"PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "0c42faa0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"What is 13 raised to the .3432 power?\u001b[32;1m\u001b[1;3m\n",
"\n",
"```python\n",
"import numpy as np\n",
"print(np.power(13, .3432))\n",
"```\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\n",
"\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Answer: 2.4116004626599237\\n'"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_math = LLMMathChain(llm=llm, prompt=PROMPT, verbose=True)\n",
"llm_math = LLMMathChain.from_llm(llm, verbose=True)\n",
"\n",
"llm_math.run(\"What is 13 raised to the .3432 power?\")"
]
@ -152,7 +56,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0c62951b",
"id": "e978bb8e",
"metadata": {},
"outputs": [],
"source": []
@ -174,7 +78,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -221,11 +221,11 @@
"\n",
"• The light from these galaxies has been traveling for over 13 billion years to reach us. - True \n",
"\n",
"• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 1995. \n",
"• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 2004. \n",
"\n",
"• Exoplanets were first discovered in 1992. - True \n",
"\n",
"• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. It is too early to tell as the JWST has not been launched yet.\n",
"• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. The JWST has not yet been launched, so it is not yet known how much detail it will be able to provide.\n",
"\"\"\"\n",
"\n",
"Original Summary:\n",
@ -296,11 +296,11 @@
"\n",
"• The light from these galaxies has been traveling for over 13 billion years to reach us. - True \n",
"\n",
"• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 1995. \n",
"• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 2004. \n",
"\n",
"• Exoplanets were first discovered in 1992. - True \n",
"\n",
"• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. It is too early to tell as the JWST has not been launched yet.\n",
"• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. The JWST has not yet been launched, so it is not yet known how much detail it will be able to provide.\n",
"\"\"\"\n",
"Result:\u001b[0m\n",
"\n",
@ -312,7 +312,7 @@
"Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\n",
"• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\n",
"• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\n",
"• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail than ever before.\n",
"• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail when it is launched in 2023.\n",
"These discoveries can spark a child's imagination about the infinite wonders of the universe.\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
@ -321,7 +321,7 @@
{
"data": {
"text/plain": [
"'Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\\n• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\\n• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\\n• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail than ever before.\\nThese discoveries can spark a child\\'s imagination about the infinite wonders of the universe.'"
"'Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\\n• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\\n• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\\n• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail when it is launched in 2023.\\nThese discoveries can spark a child\\'s imagination about the infinite wonders of the universe.'"
]
},
"execution_count": 1,
@ -334,7 +334,7 @@
"from langchain.llms import OpenAI\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"checker_chain = LLMSummarizationCheckerChain(llm=llm, verbose=True, max_checks=2)\n",
"checker_chain = LLMSummarizationCheckerChain.from_llm(llm, verbose=True, max_checks=2)\n",
"text = \"\"\"\n",
"Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\n",
"• In 2023, The JWST spotted a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\n",
@ -407,7 +407,8 @@
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n",
"\n",
"Checked Assertions:\"\"\"\n",
"Checked Assertions:\n",
"\"\"\"\n",
"\n",
"- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n",
"\n",
@ -428,7 +429,8 @@
"- It is considered the northern branch of the Norwegian Sea. True\n",
"\"\"\"\n",
"\n",
"Original Summary:\"\"\"\n",
"Original Summary:\n",
"\"\"\"\n",
"The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is one of five oceans in the world, alongside the Pacific Ocean, Atlantic Ocean, Indian Ocean, and the Southern Ocean. It is the smallest of the five oceans and is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\n",
"\"\"\"\n",
"\n",
@ -443,7 +445,7 @@
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n",
"\n",
"If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n",
"\n",
@ -555,7 +557,8 @@
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n",
"\n",
"Checked Assertions:\"\"\"\n",
"Checked Assertions:\n",
"\"\"\"\n",
"\n",
"- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n",
"\n",
@ -574,7 +577,8 @@
"- It is considered the northern branch of the Norwegian Sea. False - It is considered the northern branch of the Atlantic Ocean.\n",
"\"\"\"\n",
"\n",
"Original Summary:\"\"\"\n",
"Original Summary:\n",
"\"\"\"\n",
"\n",
"The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is an arm of the Arctic Ocean. It is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\n",
"\"\"\"\n",
@ -583,14 +587,20 @@
"\n",
"The output should have the same structure and formatting as the original summary.\n",
"\n",
"Summary:\u001b[0m\n",
"Summary:\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n",
"\n",
"If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n",
"\n",
@ -701,7 +711,8 @@
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n",
"\n",
"Checked Assertions:\"\"\"\n",
"Checked Assertions:\n",
"\"\"\"\n",
"\n",
"- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n",
"\n",
@ -718,7 +729,8 @@
"- It is considered the northern branch of the Atlantic Ocean. False - The Greenland Sea is considered part of the Arctic Ocean, not the Atlantic Ocean.\n",
"\"\"\"\n",
"\n",
"Original Summary:\"\"\"\n",
"Original Summary:\n",
"\"\"\"\n",
"\n",
"\n",
"The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is an arm of the Arctic Ocean. It is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the country of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Atlantic Ocean.\n",
@ -735,7 +747,7 @@
"\n",
"\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
"Prompt after formatting:\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n",
"\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n",
"\n",
"If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n",
"\n",
@ -813,14 +825,14 @@
"from langchain.llms import OpenAI\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"checker_chain = LLMSummarizationCheckerChain(llm=llm, verbose=True, max_checks=3)\n",
"checker_chain = LLMSummarizationCheckerChain.from_llm(llm, verbose=True, max_checks=3)\n",
"text = \"The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is one of five oceans in the world, alongside the Pacific Ocean, Atlantic Ocean, Indian Ocean, and the Southern Ocean. It is the smallest of the five oceans and is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\"\n",
"checker_chain.run(text)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -1077,7 +1089,7 @@
"'Birds are not mammals, but they are a class of their own. They lay eggs, unlike mammals which give birth to live young.'"
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@ -1087,17 +1099,10 @@
"from langchain.llms import OpenAI\n",
"\n",
"llm = OpenAI(temperature=0)\n",
"checker_chain = LLMSummarizationCheckerChain(llm=llm, max_checks=3, verbose=True)\n",
"checker_chain = LLMSummarizationCheckerChain.from_llm(llm, max_checks=3, verbose=True)\n",
"text = \"Mammals can lay eggs, birds can lay eggs, therefore birds are mammals.\"\n",
"checker_chain.run(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)"
"llm = OpenAI(temperature=0, max_tokens=512)"
]
},
{
@ -63,7 +63,9 @@
"cell_type": "code",
"execution_count": 4,
"id": "3ef64b27",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
@ -71,17 +73,17 @@
"text": [
"\n",
"\n",
"\u001B[1m> Entering new PALChain chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3mdef solution():\n",
"\u001b[1m> Entering new PALChain chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mdef solution():\n",
" \"\"\"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?\"\"\"\n",
" cindy_pets = 4\n",
" marcia_pets = cindy_pets + 2\n",
" jan_pets = marcia_pets * 3\n",
" total_pets = cindy_pets + marcia_pets + jan_pets\n",
" result = total_pets\n",
" return result\u001B[0m\n",
" return result\u001b[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
@ -139,8 +141,8 @@
"text": [
"\n",
"\n",
"\u001B[1m> Entering new PALChain chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3m# Put objects into a list to record ordering\n",
"\u001b[1m> Entering new PALChain chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n",
"objects = []\n",
"objects += [('booklet', 'blue')] * 2\n",
"objects += [('booklet', 'purple')] * 2\n",
@ -151,9 +153,9 @@
"\n",
"# Count number of purple objects\n",
"num_purple = len([object for object in objects if object[1] == 'purple'])\n",
"answer = num_purple\u001B[0m\n",
"answer = num_purple\u001b[0m\n",
"\n",
"\u001B[1m> Finished PALChain chain.\u001B[0m\n"
"\u001b[1m> Finished PALChain chain.\u001b[0m\n"
]
},
{
@ -212,8 +214,8 @@
"text": [
"\n",
"\n",
"\u001B[1m> Entering new PALChain chain...\u001B[0m\n",
"\u001B[32;1m\u001B[1;3m# Put objects into a list to record ordering\n",
"\u001b[1m> Entering new PALChain chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n",
"objects = []\n",
"objects += [('booklet', 'blue')] * 2\n",
"objects += [('booklet', 'purple')] * 2\n",
@ -224,9 +226,9 @@
"\n",
"# Count number of purple objects\n",
"num_purple = len([object for object in objects if object[1] == 'purple'])\n",
"answer = num_purple\u001B[0m\n",
"answer = num_purple\u001b[0m\n",
"\n",
"\u001B[1m> Finished chain.\u001B[0m\n"
"\u001b[1m> Finished chain.\u001b[0m\n"
]
}
],
@ -280,7 +282,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -73,7 +73,7 @@
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)"
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)"
]
},
{
@ -175,7 +175,7 @@
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True)"
"db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)"
]
},
{
@ -230,7 +230,7 @@
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)"
"db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)"
]
},
{
@ -285,7 +285,7 @@
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, top_k=3)"
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, top_k=3)"
]
},
{
@ -407,7 +407,7 @@
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)"
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)"
]
},
{
@ -569,7 +569,7 @@
}
],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n",
"db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n",
"db_chain.run(\"What are some example tracks by Bach?\")"
]
},
@ -681,7 +681,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -0,0 +1,199 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "593f7553-7038-498e-96d4-8255e5ce34f0",
"metadata": {},
"source": [
"# Creating a custom Chain\n",
"\n",
"To implement your own custom chain you can subclass `Chain` and implement the following methods:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
"metadata": {
"tags": [],
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"from typing import Any, Dict, List, Optional\n",
"\n",
"from pydantic import Extra\n",
"\n",
"from langchain.base_language import BaseLanguageModel\n",
"from langchain.callbacks.manager import (\n",
" AsyncCallbackManagerForChainRun,\n",
" CallbackManagerForChainRun,\n",
")\n",
"from langchain.chains.base import Chain\n",
"from langchain.prompts.base import BasePromptTemplate\n",
"\n",
"\n",
"class MyCustomChain(Chain):\n",
" \"\"\"\n",
" An example of a custom chain.\n",
" \"\"\"\n",
"\n",
" prompt: BasePromptTemplate\n",
" \"\"\"Prompt object to use.\"\"\"\n",
" llm: BaseLanguageModel\n",
" output_key: str = \"text\" #: :meta private:\n",
"\n",
" class Config:\n",
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
"\n",
" extra = Extra.forbid\n",
" arbitrary_types_allowed = True\n",
"\n",
" @property\n",
" def input_keys(self) -> List[str]:\n",
" \"\"\"Will be whatever keys the prompt expects.\n",
"\n",
" :meta private:\n",
" \"\"\"\n",
" return self.prompt.input_variables\n",
"\n",
" @property\n",
" def output_keys(self) -> List[str]:\n",
" \"\"\"Will always return text key.\n",
"\n",
" :meta private:\n",
" \"\"\"\n",
" return [self.output_key]\n",
"\n",
" def _call(\n",
" self,\n",
" inputs: Dict[str, Any],\n",
" run_manager: Optional[CallbackManagerForChainRun] = None,\n",
" ) -> Dict[str, str]:\n",
" # Your custom chain logic goes here\n",
" # This is just an example that mimics LLMChain\n",
" prompt_value = self.prompt.format_prompt(**inputs)\n",
" \n",
" # Whenever you call a language model, or another chain, you should pass\n",
" # a callback manager to it. This allows the inner run to be tracked by\n",
" # any callbacks that are registered on the outer run.\n",
" # You can always obtain a callback manager for this by calling\n",
" # `run_manager.get_child()` as shown below.\n",
" response = self.llm.generate_prompt(\n",
" [prompt_value],\n",
" callbacks=run_manager.get_child() if run_manager else None\n",
" )\n",
"\n",
" # If you want to log something about this run, you can do so by calling\n",
" # methods on the `run_manager`, as shown below. This will trigger any\n",
" # callbacks that are registered for that event.\n",
" if run_manager:\n",
" run_manager.on_text(\"Log something about this run\")\n",
" \n",
" return {self.output_key: response.generations[0][0].text}\n",
"\n",
" async def _acall(\n",
" self,\n",
" inputs: Dict[str, Any],\n",
" run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
" ) -> Dict[str, str]:\n",
" # Your custom chain logic goes here\n",
" # This is just an example that mimics LLMChain\n",
" prompt_value = self.prompt.format_prompt(**inputs)\n",
" \n",
" # Whenever you call a language model, or another chain, you should pass\n",
" # a callback manager to it. This allows the inner run to be tracked by\n",
" # any callbacks that are registered on the outer run.\n",
" # You can always obtain a callback manager for this by calling\n",
" # `run_manager.get_child()` as shown below.\n",
" response = await self.llm.agenerate_prompt(\n",
" [prompt_value],\n",
" callbacks=run_manager.get_child() if run_manager else None\n",
" )\n",
"\n",
" # If you want to log something about this run, you can do so by calling\n",
" # methods on the `run_manager`, as shown below. This will trigger any\n",
" # callbacks that are registered for that event.\n",
" if run_manager:\n",
" await run_manager.on_text(\"Log something about this run\")\n",
" \n",
" return {self.output_key: response.generations[0][0].text}\n",
"\n",
" @property\n",
" def _chain_type(self) -> str:\n",
" return \"my_custom_chain\"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "18361f89",
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
"Log something about this run\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.callbacks.stdout import StdOutCallbackHandler\n",
"from langchain.chat_models.openai import ChatOpenAI\n",
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"\n",
"chain = MyCustomChain(\n",
" prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n",
" llm=ChatOpenAI()\n",
")\n",
"\n",
"chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
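
Since `_acall` is implemented as well, the same chain can be driven asynchronously. A small usage sketch (run inside an async context, e.g. a notebook cell):

# Exercises MyCustomChain._acall and awaits the async callback events.
result = await chain.arun({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])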

View File

@ -589,7 +589,6 @@
"outputs": [],
"source": [
"from langchain.chains.llm import LLMChain\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
"from langchain.chains.question_answering import load_qa_chain\n",
@ -597,7 +596,7 @@
"# Construct a ConversationalRetrievalChain with a streaming llm for combine docs\n",
"# and a separate, non-streaming llm for question generation\n",
"llm = OpenAI(temperature=0)\n",
"streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"streaming_llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
"\n",
"question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n",
"doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n",

View File

@ -80,9 +80,8 @@
}
],
"source": [
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
"resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])"
]
},
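
Beyond printing to stdout, the same streaming hook can be used to collect tokens programmatically. A hedged sketch (the `CollectTokensHandler` name is illustrative, not part of the library):

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

class CollectTokensHandler(BaseCallbackHandler):
    """Accumulate streamed tokens in a list instead of printing them."""

    def __init__(self) -> None:
        self.tokens = []

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.tokens.append(token)

collector = CollectTokensHandler()
chat = ChatOpenAI(streaming=True, callbacks=[collector], temperature=0)
chat([HumanMessage(content="Write me a song about sparkling water.")])
print("".join(collector.tokens))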

View File

@ -373,9 +373,8 @@
}
],
"source": [
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
"resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])\n"
]
},

View File

@ -22,18 +22,20 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 6,
"id": "a65696a0",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.base import LLM\n",
"from typing import Optional, List, Mapping, Any"
"from typing import Any, List, Mapping, Optional\n",
"\n",
"from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
"from langchain.llms.base import LLM"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"id": "d5ceff02",
"metadata": {},
"outputs": [],
@ -46,7 +48,12 @@
" def _llm_type(self) -> str:\n",
" return \"custom\"\n",
" \n",
" def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n",
" def _call(\n",
" self,\n",
" prompt: str,\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" ) -> str:\n",
" if stop is not None:\n",
" raise ValueError(\"stop kwargs are not permitted.\")\n",
" return prompt[:self.n]\n",
@ -67,7 +74,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"id": "10e5ece6",
"metadata": {},
"outputs": [],
@ -77,7 +84,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 9,
"id": "8cd49199",
"metadata": {},
"outputs": [
@ -87,7 +94,7 @@
"'This is a '"
]
},
"execution_count": 4,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -106,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 10,
"id": "9c33fa19",
"metadata": {},
"outputs": [

View File

@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07",
"metadata": {
"tags": []
@ -21,14 +21,13 @@
"source": [
"from langchain.llms import OpenAI, Anthropic\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from langchain.schema import HumanMessage"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "77f60a4b-f786-41f2-972e-e5bb8a48dcd5",
"metadata": {
"tags": []
@ -79,7 +78,7 @@
}
],
"source": [
"llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
"resp = llm(\"Write me a song about sparkling water.\")"
]
},
@ -95,7 +94,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "a35373f1-9ee6-4753-a343-5aee749b8527",
"metadata": {
"tags": []
@ -136,7 +135,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "22665f16-e05b-473c-a4bd-ad75744ea024",
"metadata": {
"tags": []
@ -191,7 +190,7 @@
}
],
"source": [
"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
"resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])"
]
},
@ -205,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "eadae4ba-9f21-4ec8-845d-dd43b0edc2dc",
"metadata": {
"tags": []
@ -245,7 +244,7 @@
}
],
"source": [
"llm = Anthropic(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n",
"llm = Anthropic(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n",
"llm(\"Write me a song about sparkling water.\")"
]
}

View File

@ -40,7 +40,6 @@
"source": [
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.llms import GPT4All\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
]
},
@ -124,9 +123,9 @@
"outputs": [],
"source": [
"# Callbacks support token-wise streaming\n",
"callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
"callbacks = [StreamingStdOutCallbackHandler()]\n",
"# Verbose is required to pass to the callback manager\n",
"llm = GPT4All(model=local_path, callback_manager=callback_manager, verbose=True)"
"llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)"
]
},
{

View File

@ -6,7 +6,6 @@ First, you should install tracing and set up your environment properly.
You can use either a locally hosted version of this (uses Docker) or a cloud hosted version (in closed alpha).
If you're interested in using the hosted platform, please fill out the form [here](https://forms.gle/tRCEMSeopZf6TE3b6).
- [Locally Hosted Setup](./tracing/local_installation.md)
- [Cloud Hosted Setup](./tracing/hosted_installation.md)
@ -46,11 +45,12 @@ For example, here is the lowest level trace with the exact inputs/outputs to the
![](tracing/explore_llm.png)
## Changing Sessions
1. To initially record traces to a session other than `"default"`, you can set the `LANGCHAIN_SESSION` environment variable to the name of the session you want to record to:
```python
import os
os.environ["LANGCHAIN_HANDLER"] = "langchain"
os.environ["LANGCHAIN_TRACING"] = "true"
os.environ["LANGCHAIN_SESSION"] = "my_session" # Make sure this session actually exists. You can create a new session in the UI.
```

View File

@ -5,7 +5,14 @@
"id": "5371a9bb",
"metadata": {},
"source": [
"# Tracing Walkthrough"
"# Tracing Walkthrough\n",
"\n",
"There are two recommended ways to trace your LangChains:\n",
"\n",
"1. Setting the `LANGCHAIN_TRACING` environment variable to \"true\".\n",
"1. Using a context manager with tracing_enabled() to trace a particular block of code.\n",
"\n",
"**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager."
]
},
{
@ -18,24 +25,22 @@
"outputs": [],
"source": [
"import os\n",
"os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n",
"\n",
"## Uncomment this if using hosted setup.\n",
"os.environ[\"LANGCHAIN_TRACING\"] = \"true\"\n",
"\n",
"## Uncomment below if using hosted setup.\n",
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchain-api-gateway-57eoxz8z.uc.gateway.dev\" \n",
"\n",
"## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n",
"\n",
"## Uncomment below if you want traces to be recorded to \"my_session\" instead of \"default\".\n",
"# os.environ[\"LANGCHAIN_SESSION\"] = \"my_session\" \n",
"\n",
"## Better to set this environment variable in the terminal\n",
"## Uncomment this if using hosted version. Replace \"my_api_key\" with your actual API Key.\n",
"\n",
"## Uncomment below if using hosted version. Replace \"my_api_key\" with your actual API Key.\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = \"my_api_key\" \n",
"\n",
"import langchain\n",
"from langchain.agents import Tool, initialize_agent, load_tools\n",
"from langchain.agents import AgentType\n",
"from langchain.callbacks import tracing_enabled\n",
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.llms import OpenAI"
]
@ -73,8 +78,7 @@
"\u001b[32;1m\u001b[1;3m I need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 2^.123243\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
"Final Answer: 1.0891804557407723\u001b[0m\n",
"\n",
@ -104,7 +108,9 @@
"cell_type": "code",
"execution_count": 4,
"id": "4829eb1d",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
@ -113,52 +119,11 @@
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mQuestion: What is 2 raised to .123243 power?\n",
"Thought: I need a calculator to solve this problem.\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"calculator\",\n",
" \"action_input\": \"2^0.123243\"\n",
"}\n",
"```\n",
"\u001b[0m\n",
"Observation: calculator is not a valid tool, try another one.\n",
"\u001b[32;1m\u001b[1;3mI made a mistake, I need to use the correct tool for this question.\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"calculator\",\n",
" \"action_input\": \"2^0.123243\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: calculator is not a valid tool, try another one.\n",
"\u001b[32;1m\u001b[1;3mI made a mistake, the tool name is actually \"calc\" instead of \"calculator\".\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"calc\",\n",
" \"action_input\": \"2^0.123243\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: calc is not a valid tool, try another one.\n",
"\u001b[32;1m\u001b[1;3mI made another mistake, the tool name is actually \"Calculator\" instead of \"calc\".\n",
"Action:\n",
"```\n",
"{\n",
" \"action\": \"Calculator\",\n",
" \"action_input\": \"2^0.123243\"\n",
"}\n",
"```\n",
"\n",
"\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\n",
"\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mThe final answer is 1.0891804557407723.\n",
"\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 2 ^ .123243\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n",
"Final Answer: 1.0891804557407723\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
@ -186,8 +151,182 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "76abfd82",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 2 ^ .123243\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n",
"Final Answer: 1.0891804557407723\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 5 ^ .123243\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2193914912400514\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n",
"Final Answer: 1.2193914912400514\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
}
],
"source": [
"# Both of the agent runs will be traced because the environment variable is set\n",
"agent.run(\"What is 2 raised to .123243 power?\")\n",
"with tracing_enabled() as session:\n",
" agent.run(\"What is 5 raised to .123243 power?\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fe833c33-033f-4806-be0c-cc3d147db13d",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 5 ^ .123243\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2193914912400514\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n",
"Final Answer: 1.2193914912400514\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 2 ^ .123243\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n",
"Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n",
"Final Answer: 1.0891804557407723\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'1.0891804557407723'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now, we unset the environment variable and use a context manager.\n",
"if \"LANGCHAIN_TRACING\" in os.environ:\n",
" del os.environ[\"LANGCHAIN_TRACING\"]\n",
"\n",
"# here, we are writing traces to \"my_test_session\"\n",
"with tracing_enabled(\"my_session\") as session:\n",
" assert session\n",
" agent.run(\"What is 5 raised to .123243 power?\") # this should be traced\n",
"\n",
"agent.run(\"What is 2 raised to .123243 power?\") # this should not be traced"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "b34105a4-be8e-46e4-8abe-01adba3ba727",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\n",
"\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 3^0.123\u001b[0m\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n",
"Action: Calculator\n",
"Action Input: 2^0.123\u001b[0m\u001b[32;1m\u001b[1;3mAny number raised to the power of 0 is 1, but I'm not sure about a decimal power.\n",
"Action: Calculator\n",
"Action Input: 1^.123\u001b[0m\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.1446847956963533\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0889970153361064\u001b[0m\n",
"Thought:\n",
"Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0\u001b[0m\n",
"Thought:\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'1.0'"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The context manager is concurrency safe:\n",
"import asyncio \n",
"if \"LANGCHAIN_TRACING\" in os.environ:\n",
" del os.environ[\"LANGCHAIN_TRACING\"]\n",
" \n",
"questions = [f\"What is {i} raised to .123 power?\" for i in range(1,4)]\n",
"\n",
"# start a background task\n",
"task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced\n",
"with tracing_enabled() as session:\n",
" assert session\n",
" tasks = [agent.arun(q) for q in questions[1:3]] # these should be traced\n",
" await asyncio.gather(*tasks)\n",
"\n",
"await task"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e46c85b-2ac0-4661-abed-9c2bf3036820",
"metadata": {},
"outputs": [],
"source": []
@ -209,7 +348,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@ -5,11 +5,6 @@ from typing import Optional
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.callbacks import (
set_default_callback_manager,
set_handler,
set_tracing_callback_manager,
)
from langchain.chains import (
ConversationChain,
LLMBashChain,
@ -67,7 +62,6 @@ del metadata # optional, avoids polluting the results of dir(__package__)
verbose: bool = False
llm_cache: Optional[BaseCache] = None
set_default_callback_manager()
# For backwards compatibility
SerpAPIChain = SerpAPIWrapper
@ -119,7 +113,5 @@ __all__ = [
"VectorDBQAWithSourcesChain",
"QAWithSourcesChain",
"PALChain",
"set_handler",
"set_tracing_callback_manager",
"LlamaCpp",
]

View File

@ -13,7 +13,13 @@ import yaml
from pydantic import BaseModel, root_validator
from langchain.agents.tools import InvalidTool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
@ -23,7 +29,6 @@ from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
)
@ -46,13 +51,17 @@ class BaseSingleActionAgent(BaseModel):
@abstractmethod
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@ -61,13 +70,17 @@ class BaseSingleActionAgent(BaseModel):
@abstractmethod
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@ -170,13 +183,17 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@ -185,13 +202,17 @@ class BaseMultiActionAgent(BaseModel):
@abstractmethod
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[List[AgentAction], AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
@ -285,38 +306,52 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
output = self.llm_chain.run(
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
intermediate_steps=intermediate_steps,
stop=self.stop,
callbacks=callbacks,
**kwargs,
)
return self.output_parser.parse(output)
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
output = await self.llm_chain.arun(
intermediate_steps=intermediate_steps, stop=self.stop, **kwargs
intermediate_steps=intermediate_steps,
stop=self.stop,
callbacks=callbacks,
**kwargs,
)
return self.output_parser.parse(output)
@ -368,37 +403,45 @@ class Agent(BaseSingleActionAgent):
return thoughts
def plan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = self.llm_chain.predict(**full_inputs)
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
async def aplan(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date,
along with observations
callbacks: Callbacks to run.
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = await self.llm_chain.apredict(**full_inputs)
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
def get_full_inputs(
@ -636,24 +679,27 @@ class AgentExecutor(Chain):
return True
def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]:
self.callback_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
def _return(
self,
output: AgentFinish,
intermediate_steps: list,
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if run_manager:
run_manager.on_agent_finish(output, color="green", verbose=self.verbose)
final_output = output.return_values
if self.return_intermediate_steps:
final_output["intermediate_steps"] = intermediate_steps
return final_output
async def _areturn(
self, output: AgentFinish, intermediate_steps: list
self,
output: AgentFinish,
intermediate_steps: list,
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if self.callback_manager.is_async:
await self.callback_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
else:
self.callback_manager.on_agent_finish(
if run_manager:
await run_manager.on_agent_finish(
output, color="green", verbose=self.verbose
)
final_output = output.return_values
@ -667,13 +713,18 @@ class AgentExecutor(Chain):
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
# Call the LLM to see what to do.
output = self.agent.plan(intermediate_steps, **inputs)
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
@ -684,9 +735,8 @@ class AgentExecutor(Chain):
actions = output
result = []
for agent_action in actions:
self.callback_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
@ -700,6 +750,7 @@ class AgentExecutor(Chain):
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
@ -708,6 +759,7 @@ class AgentExecutor(Chain):
agent_action.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
result.append((agent_action, observation))
@ -719,13 +771,18 @@ class AgentExecutor(Chain):
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
# Call the LLM to see what to do.
output = await self.agent.aplan(intermediate_steps, **inputs)
output = await self.agent.aplan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
return output
@ -738,12 +795,8 @@ class AgentExecutor(Chain):
async def _aperform_agent_action(
agent_action: AgentAction,
) -> Tuple[AgentAction, str]:
if self.callback_manager.is_async:
await self.callback_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
else:
self.callback_manager.on_agent_action(
if run_manager:
await run_manager.on_agent_action(
agent_action, verbose=self.verbose, color="green"
)
# Otherwise we lookup the tool
@ -759,6 +812,7 @@ class AgentExecutor(Chain):
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
@ -767,6 +821,7 @@ class AgentExecutor(Chain):
agent_action.tool,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
return agent_action, observation
@ -778,7 +833,11 @@ class AgentExecutor(Chain):
return list(result)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
@ -794,10 +853,16 @@ class AgentExecutor(Chain):
# We now enter the agent loop (until it returns something).
while self._should_continue(iterations, time_elapsed):
next_step_output = self._take_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return self._return(next_step_output, intermediate_steps)
return self._return(
next_step_output, intermediate_steps, run_manager=run_manager
)
intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1:
@ -813,7 +878,11 @@ class AgentExecutor(Chain):
)
return self._return(output, intermediate_steps)
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
@ -831,7 +900,11 @@ class AgentExecutor(Chain):
try:
while self._should_continue(iterations, time_elapsed):
next_step_output = await self._atake_next_step(
name_to_tool_map, color_mapping, inputs, intermediate_steps
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return await self._areturn(next_step_output, intermediate_steps)
@ -849,7 +922,9 @@ class AgentExecutor(Chain):
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return await self._areturn(output, intermediate_steps)
return await self._areturn(
output, intermediate_steps, run_manager=run_manager
)
except TimeoutError:
# stop early when interrupted by the async timeout
output = self.agent.return_stopped_response(

View File

@ -54,7 +54,7 @@ class FileManagementToolkit(BaseToolkit):
tools: List[BaseTool] = []
for tool in allowed_tools:
tool_cls = _FILE_TOOLS[tool]
tools.append(tool_cls(root_dir=self.root_dir))
tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore
return tools

View File

@ -28,13 +28,13 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import (
from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
from langchain.tools.requests.tool import BaseRequestsTool

View File

@ -4,10 +4,10 @@ from typing import List, Optional
from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.powerbi.prompt import QUESTION_TO_QUERY
from langchain.tools.powerbi.tool import (

View File

@ -4,7 +4,7 @@ from typing import List
from pydantic import Field
from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.schema import BaseLanguageModel
from langchain.base_language import BaseLanguageModel
from langchain.sql_database import SQLDatabase
from langchain.tools import BaseTool
from langchain.tools.sql_database.tool import (

View File

@ -5,6 +5,7 @@ from pydantic import Field
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.chat.output_parser import ChatOutputParser
from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
@ -13,7 +14,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction, BaseLanguageModel
from langchain.schema import AgentAction
from langchain.tools import BaseTool

View File

@ -9,10 +9,10 @@ from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.agent_types import AgentType
from langchain.agents.conversational.output_parser import ConvoOutputParser
from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import (
SUFFIX,
TEMPLATE_TOOL_RESPONSE,
)
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts.base import BasePromptTemplate
@ -24,7 +25,6 @@ from langchain.prompts.chat import (
from langchain.schema import (
AgentAction,
AIMessage,
BaseLanguageModel,
BaseMessage,
BaseOutputParser,
HumanMessage,

View File

@ -4,8 +4,8 @@ from typing import Any, Optional, Sequence
from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_types import AgentType
from langchain.agents.loading import AGENT_TO_CLASS, load_agent
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@ -103,8 +103,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool:
return Tool(
name="Calculator",
description="Useful for when you need to answer questions about math.",
func=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).run,
coroutine=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).arun,
func=LLMMathChain.from_llm(llm=llm).run,
coroutine=LLMMathChain.from_llm(llm=llm).arun,
)

View File

@ -10,10 +10,10 @@ from langchain.agents.agent_types import AgentType
from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX
from langchain.agents.tools import Tool
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool

View File

@ -1,9 +1,14 @@
"""Interface for tools."""
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import BaseModel, validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.tools.base import BaseTool, StructuredTool
@ -44,14 +49,44 @@ class Tool(BaseTool):
)
return tuple(all_args), {}
def _run(self, *args: Any, **kwargs: Any) -> Any:
def _run(
self,
*args: Any,
run_manager: Optional[CallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
"""Use the tool."""
return self.func(*args, **kwargs)
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
async def _arun(
self,
*args: Any,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
**kwargs: Any,
) -> Any:
"""Use the tool asynchronously."""
if self.coroutine:
return await self.coroutine(*args, **kwargs)
new_argument_supported = signature(self.coroutine).parameters.get(
"callbacks"
)
return (
await self.coroutine(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else await self.coroutine(*args, **kwargs)
)
raise NotImplementedError("Tool does not support async")
# TODO: this is for backwards compatibility, remove in future
@ -70,11 +105,17 @@ class InvalidTool(BaseTool):
name = "invalid_tool"
description = "Called when tool name is invalid."
def _run(self, tool_name: str) -> str:
def _run(
self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None
) -> str:
"""Use the tool."""
return f"{tool_name} is not a valid tool, try another one."
async def _arun(self, tool_name: str) -> str:
async def _arun(
self,
tool_name: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Use the tool asynchronously."""
return f"{tool_name} is not a valid tool, try another one."

View File

@ -0,0 +1,60 @@
"""Base class for all language models."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
def _get_num_tokens_default_method(text: str) -> int:
"""Get the number of tokens present in the text."""
# TODO: this method may not be exact.
# TODO: this method may differ based on model (eg codex).
try:
from transformers import GPT2TokenizerFast
except ImportError:
raise ValueError(
"Could not import transformers python package. "
"This is needed in order to calculate get_num_tokens. "
"Please install it with `pip install transformers`."
)
# create a GPT-2 tokenizer instance
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    # tokenize the text using the GPT-2 tokenizer
tokenized_text = tokenizer.tokenize(text)
# calculate the number of tokens in the tokenized text
return len(tokenized_text)
class BaseLanguageModel(BaseModel, ABC):
@abstractmethod
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
return _get_num_tokens_default_method(text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the message."""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
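
A quick usage sketch of the token-counting helper (every `BaseLanguageModel` subclass exposes it; subclasses may override the GPT-2 default with a model-specific tokenizer):

from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
print(llm.get_num_tokens("Tokens are not the same as words."))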

View File

@ -1,80 +1,19 @@
"""Callback handlers that allow listening to events in LangChain."""
import os
from contextlib import contextmanager
from typing import Generator, Optional
from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.base import (
AsyncCallbackManager,
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.manager import (
get_openai_callback,
tracing_enabled,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.shared import SharedCallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.tracers import SharedLangChainTracer
from langchain.callbacks.wandb_callback import WandbCallbackHandler
def get_callback_manager() -> BaseCallbackManager:
"""Return the shared callback manager."""
return SharedCallbackManager()
def set_handler(handler: BaseCallbackHandler) -> None:
"""Set handler."""
callback = get_callback_manager()
callback.set_handler(handler)
def set_default_callback_manager() -> None:
"""Set default callback manager."""
default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout")
if default_handler == "stdout":
set_handler(StdOutCallbackHandler())
elif default_handler == "langchain":
session = os.environ.get("LANGCHAIN_SESSION")
set_tracing_callback_manager(session)
else:
raise ValueError(
f"LANGCHAIN_HANDLER should be one of `stdout` "
f"or `langchain`, got {default_handler}"
)
def set_tracing_callback_manager(session_name: Optional[str] = None) -> None:
"""Set tracing callback manager."""
handler = SharedLangChainTracer()
callback = get_callback_manager()
callback.set_handlers([handler, StdOutCallbackHandler()])
if session_name is None:
handler.load_default_session()
else:
try:
handler.load_session(session_name)
except Exception:
raise ValueError(f"session {session_name} not found")
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get OpenAI callback handler in a context manager."""
handler = OpenAICallbackHandler()
manager = get_callback_manager()
manager.add_handler(handler)
yield handler
manager.remove_handler(handler)
__all__ = [
"CallbackManager",
"AsyncCallbackManager",
"OpenAICallbackHandler",
"SharedCallbackManager",
"StdOutCallbackHandler",
"AimCallbackHandler",
"WandbCallbackHandler",
@ -82,8 +21,5 @@ __all__ = [
"CometCallbackHandler",
"AsyncIteratorCallbackHandler",
"get_openai_callback",
"set_tracing_callback_manager",
"set_default_callback_manager",
"set_handler",
"get_callback_manager",
"tracing_enabled",
]
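
The utilities that survive the refactor keep their old entry point, so existing code such as the following continues to work (the prompt string is illustrative):

from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
with get_openai_callback() as cb:
    llm("Tell me a joke")
    # The handler now attaches per-run instead of mutating shared state.
    print(cb.total_tokens)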

View File

@ -1,19 +1,174 @@
"""Base callback handler that can be used to handle callbacks from langchain."""
import asyncio
import functools
from abc import ABC, abstractmethod
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
import copy
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain.schema import AgentAction, AgentFinish, LLMResult
class BaseCallbackHandler(ABC):
"""Base callback handler that can be used to handle callbacks from langchain."""
class LLMManagerMixin:
"""Mixin for LLM callbacks."""
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return False
def on_llm_new_token(
self,
token: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Only available when streaming is enabled."""
def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM ends running."""
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM errors."""
class ChainManagerMixin:
"""Mixin for chain callbacks."""
def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain ends running."""
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain errors."""
def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent action."""
def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on agent end."""
class ToolManagerMixin:
"""Mixin for tool callbacks."""
def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool ends running."""
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool errors."""
class CallbackManagerMixin:
"""Mixin for callback manager."""
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running."""
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
class RunManagerMixin:
"""Mixin for run manager."""
def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on arbitrary text."""
class BaseCallbackHandler(
LLMManagerMixin,
ChainManagerMixin,
ToolManagerMixin,
CallbackManagerMixin,
RunManagerMixin,
):
"""Base callback handler that can be used to handle callbacks from langchain."""
@property
def ignore_llm(self) -> bool:
@ -30,480 +185,197 @@ class BaseCallbackHandler(ABC):
"""Whether to ignore agent callbacks."""
return False
@abstractmethod
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
@abstractmethod
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
async def on_llm_new_token(
self,
token: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
@abstractmethod
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
async def on_llm_end(
self,
response: LLMResult,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when LLM ends running."""
@abstractmethod
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
@abstractmethod
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
@abstractmethod
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
async def on_chain_end(
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
@abstractmethod
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when chain errors."""
@abstractmethod
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
@abstractmethod
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
async def on_tool_end(
self,
output: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
@abstractmethod
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run when tool errors."""
@abstractmethod
def on_text(self, text: str, **kwargs: Any) -> Any:
async def on_text(
self,
text: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on arbitrary text."""
@abstractmethod
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
async def on_agent_action(
self,
action: AgentAction,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on agent action."""
@abstractmethod
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
async def on_agent_finish(
self,
finish: AgentFinish,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Run on agent end."""
class BaseCallbackManager(BaseCallbackHandler, ABC):
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that can be used to handle callbacks from LangChain."""
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None,
) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
self.inheritable_handlers: List[BaseCallbackHandler] = (
inheritable_handlers or []
)
self.parent_run_id: Optional[UUID] = parent_run_id
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""
return False
@abstractmethod
def add_handler(self, callback: BaseCallbackHandler) -> None:
def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
if inherit:
self.inheritable_handlers.append(handler)
@abstractmethod
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
self.inheritable_handlers.remove(handler)
def set_handler(self, handler: BaseCallbackHandler) -> None:
def set_handlers(
self, handlers: List[BaseCallbackHandler], inherit: bool = True
) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = []
self.inheritable_handlers = []
for handler in handlers:
self.add_handler(handler, inherit=inherit)
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager."""
self.set_handlers([handler])
self.set_handlers([handler], inherit=inherit)
@abstractmethod
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM starts running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_start(serialized, prompts, **kwargs)
def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM generates a new token."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_new_token(token, **kwargs)
def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM ends running."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_end(response)
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM errors."""
for handler in self.handlers:
if not handler.ignore_llm:
if verbose or handler.always_verbose:
handler.on_llm_error(error)
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain starts running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None:
"""Run when chain ends running."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_end(outputs)
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain errors."""
for handler in self.handlers:
if not handler.ignore_chain:
if verbose or handler.always_verbose:
handler.on_chain_error(error)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool starts running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when tool ends running."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_end(output, **kwargs)
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool errors."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_tool_error(error)
def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run on additional input from chains and agents."""
for handler in self.handlers:
if verbose or handler.always_verbose:
handler.on_text(text, **kwargs)
def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent end."""
for handler in self.handlers:
if not handler.ignore_agent:
if verbose or handler.always_verbose:
handler.on_agent_finish(finish, **kwargs)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
async def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
async def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
async def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
async def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
async def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
async def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
"""Run on agent action."""
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
async def _handle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
if verbose or handler.always_verbose:
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
def __copy__(self) -> "BaseCallbackManager":
return self.__class__(
self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id
)
class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that can be used to handle callbacks from LangChain."""
@property
def is_async(self) -> bool:
"""Return whether the handler is async."""
return True
def __init__(self, handlers: List[BaseCallbackHandler]) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
async def _handle_event(
self,
event_name: str,
ignore_condition_name: Optional[str],
verbose: bool,
*args: Any,
**kwargs: Any
) -> None:
"""Generic event handler for AsyncCallbackManager."""
await asyncio.gather(
*(
_handle_event_for_handler(
handler, event_name, ignore_condition_name, verbose, *args, **kwargs
def __deepcopy__(self, memo: dict) -> "BaseCallbackManager":
return self.__class__(
[copy.deepcopy(handler, memo) for handler in self.handlers],
[copy.deepcopy(handler, memo) for handler in self.inheritable_handlers],
self.parent_run_id,
)
for handler in self.handlers
)
)
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM starts running."""
await self._handle_event(
"on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs
)
async def on_llm_new_token(
self, token: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
await self._handle_event(
"on_llm_new_token", "ignore_llm", verbose, token, **kwargs
)
async def on_llm_end(
self, response: LLMResult, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when LLM ends running."""
await self._handle_event(
"on_llm_end", "ignore_llm", verbose, response, **kwargs
)
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when LLM errors."""
await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs)
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain starts running."""
await self._handle_event(
"on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs
)
async def on_chain_end(
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any
) -> None:
"""Run when chain ends running."""
await self._handle_event(
"on_chain_end", "ignore_chain", verbose, outputs, **kwargs
)
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when chain errors."""
await self._handle_event(
"on_chain_error", "ignore_chain", verbose, error, **kwargs
)
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool starts running."""
await self._handle_event(
"on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs
)
async def on_tool_end(
self, output: str, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when tool ends running."""
await self._handle_event(
"on_tool_end", "ignore_agent", verbose, output, **kwargs
)
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
verbose: bool = False,
**kwargs: Any
) -> None:
"""Run when tool errors."""
await self._handle_event(
"on_tool_error", "ignore_agent", verbose, error, **kwargs
)
async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None:
"""Run when text is printed."""
await self._handle_event("on_text", None, verbose, text, **kwargs)
async def on_agent_action(
self, action: AgentAction, verbose: bool = False, **kwargs: Any
) -> None:
"""Run on agent action."""
await self._handle_event(
"on_agent_action", "ignore_agent", verbose, action, **kwargs
)
async def on_agent_finish(
self, finish: AgentFinish, verbose: bool = False, **kwargs: Any
) -> None:
"""Run when agent finishes."""
await self._handle_event(
"on_agent_finish", "ignore_agent", verbose, finish, **kwargs
)
def add_handler(self, handler: BaseCallbackHandler) -> None:
"""Add a handler to the callback manager."""
self.handlers.append(handler)
def remove_handler(self, handler: BaseCallbackHandler) -> None:
"""Remove a handler from the callback manager."""
self.handlers.remove(handler)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
self.handlers = handlers


@@ -0,0 +1,736 @@
from __future__ import annotations
import asyncio
import copy
import functools
import os
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union
from uuid import UUID, uuid4
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
ChainManagerMixin,
LLMManagerMixin,
RunManagerMixin,
ToolManagerMixin,
)
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.schema import AgentAction, AgentFinish, LLMResult
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar(
"tracing_callback", default=None
)
@contextmanager
def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
"""Get OpenAI callback handler in a context manager."""
cb = OpenAICallbackHandler()
openai_callback_var.set(cb)
yield cb
openai_callback_var.set(None)
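A minimal usage sketch for the context manager above (a sketch, not part of the diff; `llm` is assumed to be an OpenAI LLM instance):

    from langchain.callbacks.manager import get_openai_callback

    with get_openai_callback() as cb:
        llm("Tell me a joke")   # every run inside the block reports into `cb`
    print(cb.total_tokens)      # _configure() picks the handler up via the contextvar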
@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSession, None, None]:
"""Get OpenAI callback handler in a context manager."""
cb = LangChainTracer()
session = cb.load_session(session_name)
tracing_callback_var.set(cb)
yield session
tracing_callback_var.set(None)
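A similar sketch for `tracing_enabled` (the `agent` object is hypothetical):

    from langchain.callbacks.manager import tracing_enabled

    with tracing_enabled(session_name="my-session") as session:
        agent.run("What is 2 ** 0.43?")  # runs inside the block attach to `session`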
def _handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
for handler in handlers:
try:
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
getattr(handler, event_name)(*args, **kwargs)
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
async def _ahandle_event_for_handler(
handler: BaseCallbackHandler,
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
try:
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
event = getattr(handler, event_name)
if asyncio.iscoroutinefunction(event):
await event(*args, **kwargs)
else:
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
)
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
async def _ahandle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for AsyncCallbackManager."""
await asyncio.gather(
*(
_ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
)
for handler in handlers
)
)
BRM = TypeVar("BRM", bound="BaseRunManager")
class BaseRunManager(RunManagerMixin):
"""Base class for run manager (a bound callback manager)."""
def __init__(
self,
run_id: UUID,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler],
parent_run_id: Optional[UUID] = None,
) -> None:
"""Initialize run manager."""
self.run_id = run_id
self.handlers = handlers
self.inheritable_handlers = inheritable_handlers
self.parent_run_id = parent_run_id
@classmethod
def get_noop_manager(cls: Type[BRM]) -> BRM:
"""Return a manager that doesn't perform any operations."""
return cls(uuid4(), [], [])
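`get_noop_manager` exists so component code can invoke callback events unconditionally; a sketch of the fallback idiom this PR uses (see APIChain later in this diff):

    # Inside a chain's _call:
    _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
    _run_manager.on_text("...")  # safe: the noop manager has no handlers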
class RunManager(BaseRunManager):
"""Sync Run Manager."""
def on_text(
self,
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received."""
_handle_event(
self.handlers,
"on_text",
None,
text,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncRunManager(BaseRunManager):
"""Async Run Manager."""
async def on_text(
self,
text: str,
**kwargs: Any,
) -> Any:
"""Run when text is received."""
await _ahandle_event(
self.handlers,
"on_text",
None,
text,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
"""Callback manager for LLM run."""
def on_llm_new_token(
self,
token: str,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token."""
_handle_event(
self.handlers,
"on_llm_new_token",
"ignore_llm",
token=token,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
_handle_event(
self.handlers,
"on_llm_end",
"ignore_llm",
response,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
_handle_event(
self.handlers,
"on_llm_error",
"ignore_llm",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
"""Async callback manager for LLM run."""
async def on_llm_new_token(
self,
token: str,
**kwargs: Any,
) -> None:
"""Run when LLM generates a new token."""
await _ahandle_event(
self.handlers,
"on_llm_new_token",
"ignore_llm",
token,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
await _ahandle_event(
self.handlers,
"on_llm_end",
"ignore_llm",
response,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when LLM errors."""
await _ahandle_event(
self.handlers,
"on_llm_error",
"ignore_llm",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
"""Callback manager for chain run."""
def get_child(self) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
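A hedged sketch of the nesting `get_child` produces when the managers are driven by hand (event payloads are illustrative):

    from langchain.callbacks.manager import CallbackManager
    from langchain.callbacks.stdout import StdOutCallbackHandler
    from langchain.schema import LLMResult

    cm = CallbackManager([])
    cm.add_handler(StdOutCallbackHandler())  # inherit=True, so children see it too
    chain_run = cm.on_chain_start({"name": "outer"}, {"question": "hi"})
    child = chain_run.get_child()            # parent_run_id == chain_run.run_id
    llm_run = child.on_llm_start({"name": "llm"}, ["hi"])
    llm_run.on_llm_end(LLMResult(generations=[[]]))
    chain_run.on_chain_end({"answer": "ok"})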
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
_handle_event(
self.handlers,
"on_chain_end",
"ignore_chain",
outputs,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when chain errors."""
_handle_event(
self.handlers,
"on_chain_error",
"ignore_chain",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received."""
_handle_event(
self.handlers,
"on_agent_action",
"ignore_agent",
action,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received."""
_handle_event(
self.handlers,
"on_agent_finish",
"ignore_agent",
finish,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
"""Async callback manager for chain run."""
def get_child(self) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
await _ahandle_event(
self.handlers,
"on_chain_end",
"ignore_chain",
outputs,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when chain errors."""
await _ahandle_event(
self.handlers,
"on_chain_error",
"ignore_chain",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received."""
await _ahandle_event(
self.handlers,
"on_agent_action",
"ignore_agent",
action,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received."""
await _ahandle_event(
self.handlers,
"on_agent_finish",
"ignore_agent",
finish,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
"""Callback manager for tool run."""
def get_child(self) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
def on_tool_end(
self,
output: str,
**kwargs: Any,
) -> None:
"""Run when tool ends running."""
_handle_event(
self.handlers,
"on_tool_end",
"ignore_agent",
output,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when tool errors."""
_handle_event(
self.handlers,
"on_tool_error",
"ignore_agent",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
"""Async callback manager for tool run."""
def get_child(self) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
return manager
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
await _ahandle_event(
self.handlers,
"on_tool_end",
"ignore_agent",
output,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
**kwargs: Any,
) -> None:
"""Run when tool errors."""
await _ahandle_event(
self.handlers,
"on_tool_error",
"ignore_agent",
error,
run_id=self.run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
class CallbackManager(BaseCallbackManager):
"""Callback manager that can be used to handle callbacks from langchain."""
def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForChainRun:
"""Run when chain starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_chain_start",
"ignore_chain",
serialized,
inputs,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return CallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForToolRun:
"""Run when tool starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_tool_start",
"ignore_agent",
serialized,
input_str,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return CallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
@classmethod
def configure(
cls,
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
) -> CallbackManager:
"""Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
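A sketch of how a component is expected to call `configure` (`self`, `callbacks`, and `inputs` are hypothetical; inheritable callbacks typically come from the object's constructor, local ones from the current call):

    manager = CallbackManager.configure(
        inheritable_callbacks=self.callbacks,  # e.g. set when the chain was built
        local_callbacks=callbacks,             # e.g. passed to this run() call only
        verbose=self.verbose,                  # True adds a StdOutCallbackHandler
    )
    run_manager = manager.on_chain_start({"name": "my_chain"}, inputs)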
class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that can be used to handle callbacks from LangChain."""
@property
def is_async(self) -> bool:
"""Return whether the handler is async."""
return True
async def on_llm_start(
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForLLMRun:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForChainRun:
"""Run when chain starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_chain_start",
"ignore_chain",
serialized,
inputs,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
run_id: Optional[UUID] = None,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForToolRun:
"""Run when tool starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_tool_start",
"ignore_agent",
serialized,
input_str,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
@classmethod
def configure(
cls,
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
) -> AsyncCallbackManager:
"""Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
def _configure(
callback_manager_cls: Type[T],
inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None,
verbose: bool = False,
) -> T:
"""Configure the callback manager."""
callback_manager = callback_manager_cls([])
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
inheritable_callbacks_ = inheritable_callbacks or []
callback_manager = callback_manager_cls(
handlers=inheritable_callbacks_,
inheritable_handlers=inheritable_callbacks_,
)
else:
callback_manager = callback_manager_cls(
handlers=inheritable_callbacks.handlers,
inheritable_handlers=inheritable_callbacks.inheritable_handlers,
parent_run_id=inheritable_callbacks.parent_run_id,
)
callback_manager = copy.deepcopy(callback_manager)
local_handlers_ = (
local_callbacks
if isinstance(local_callbacks, list)
else (local_callbacks.handlers if local_callbacks else [])
)
for handler in local_handlers_:
callback_manager.add_handler(copy.deepcopy(handler), False)
tracer = tracing_callback_var.get()
open_ai = openai_callback_var.get()
tracing_enabled_ = (
os.environ.get("LANGCHAIN_TRACING") is not None
or tracer is not None
or os.environ.get("LANGCHAIN_HANDLER") is not None
)
tracer_session = os.environ.get("LANGCHAIN_SESSION")
if tracer_session is None:
tracer_session = "default"
if verbose or tracing_enabled_ or open_ai is not None:
if verbose and not any(
isinstance(handler, StdOutCallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(StdOutCallbackHandler(), False)
if tracing_enabled_ and not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
if tracer:
callback_manager.add_handler(tracer, True)
else:
handler = LangChainTracer()
handler.load_session(tracer_session)
callback_manager.add_handler(handler, True)
if open_ai is not None and not any(
isinstance(handler, OpenAICallbackHandler)
for handler in callback_manager.handlers
):
callback_manager.add_handler(open_ai, True)
return callback_manager
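`_configure` also folds in global context; a sketch of enabling tracing purely through the environment (the variable names are the ones read above):

    import os

    os.environ["LANGCHAIN_TRACING"] = "true"        # only presence is checked
    os.environ["LANGCHAIN_SESSION"] = "my-session"  # falls back to "default"
    # every subsequent _configure() call now attaches a LangChainTracer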


@@ -148,16 +148,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
"""Do nothing."""
pass
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
) -> None:
"""Run when agent ends."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
@@ -167,3 +157,11 @@ class OpenAICallbackHandler(BaseCallbackHandler):
) -> None:
"""Run on agent end."""
pass
def __copy__(self) -> "OpenAICallbackHandler":
"""Return a copy of the callback handler."""
return self
def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler":
"""Return a deep copy of the callback handler."""
return self
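These overrides make the handler copy-proof on purpose: `_configure` deep-copies managers and handlers, and returning `self` is what lets a single `OpenAICallbackHandler` keep accumulating token counts across nested runs. A quick sketch of the invariant:

    import copy

    cb = OpenAICallbackHandler()
    assert copy.deepcopy(cb) is cb  # nested runs all report into the same counter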


@@ -1,127 +0,0 @@
"""A shared CallbackManager."""
import threading
from typing import Any, Dict, List, Union
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
CallbackManager,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
class Singleton:
"""A thread-safe singleton class that can be inherited from."""
_instance = None
_lock = threading.Lock()
def __new__(cls) -> Any:
"""Create a new shared instance of the class."""
if cls._instance is None:
with cls._lock:
# Another thread could have created the instance
# before we acquired the lock. So check that the
# instance is still nonexistent.
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
class SharedCallbackManager(Singleton, BaseCallbackManager):
"""A thread-safe singleton CallbackManager."""
_callback_manager: CallbackManager = CallbackManager(handlers=[])
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Run when LLM starts running."""
with self._lock:
self._callback_manager.on_llm_start(serialized, prompts, **kwargs)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
with self._lock:
self._callback_manager.on_llm_end(response, **kwargs)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run when LLM generates a new token."""
with self._lock:
self._callback_manager.on_llm_new_token(token, **kwargs)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors."""
with self._lock:
self._callback_manager.on_llm_error(error, **kwargs)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Run when chain starts running."""
with self._lock:
self._callback_manager.on_chain_start(serialized, inputs, **kwargs)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
with self._lock:
self._callback_manager.on_chain_end(outputs, **kwargs)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors."""
with self._lock:
self._callback_manager.on_chain_error(error, **kwargs)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
"""Run when tool starts running."""
with self._lock:
self._callback_manager.on_tool_start(serialized, input_str, **kwargs)
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
with self._lock:
self._callback_manager.on_agent_action(action, **kwargs)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
with self._lock:
self._callback_manager.on_tool_end(output, **kwargs)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors."""
with self._lock:
self._callback_manager.on_tool_error(error, **kwargs)
def on_text(self, text: str, **kwargs: Any) -> None:
"""Run on arbitrary text."""
with self._lock:
self._callback_manager.on_text(text, **kwargs)
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Run on agent end."""
with self._lock:
self._callback_manager.on_agent_finish(finish, **kwargs)
def add_handler(self, callback: BaseCallbackHandler) -> None:
"""Add a callback to the callback manager."""
with self._lock:
self._callback_manager.add_handler(callback)
def remove_handler(self, callback: BaseCallbackHandler) -> None:
"""Remove a callback from the callback manager."""
with self._lock:
self._callback_manager.remove_handler(callback)
def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
"""Set handlers as the only handlers on the callback manager."""
with self._lock:
self._callback_manager.handlers = handlers


@@ -91,7 +91,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Optional[str],
**kwargs: Any,
) -> None:
"""Run when agent ends."""
print_text(text, color=color if color else self.color, end=end)


@@ -1,12 +1,5 @@
"""Tracers that record execution of LangChain runs."""
from langchain.callbacks.tracers.base import SharedTracer, Tracer
from langchain.callbacks.tracers.langchain import BaseLangChainTracer
from langchain.callbacks.tracers.langchain import LangChainTracer
class SharedLangChainTracer(SharedTracer, BaseLangChainTracer):
"""Shared tracer that records LangChain execution to LangChain endpoint."""
class LangChainTracer(Tracer, BaseLangChainTracer):
"""Tracer that records LangChain execution to LangChain endpoint."""
__all__ = ["LangChainTracer"]


@@ -1,14 +1,12 @@
"""Base interfaces for tracing runs."""
from __future__ import annotations
import threading
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.shared import Singleton
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
@@ -16,7 +14,7 @@ from langchain.callbacks.tracers.schemas import (
TracerSession,
TracerSessionCreate,
)
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult
class TracerException(Exception):
@@ -26,13 +24,25 @@ class TracerException(Exception):
class BaseTracer(BaseCallbackHandler, ABC):
"""Base interface for tracers."""
@abstractmethod
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {}
self.session: Optional[TracerSession] = None
@staticmethod
def _add_child_run(
self,
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
if isinstance(child_run, LLMRun):
parent_run.child_llm_runs.append(child_run)
elif isinstance(child_run, ChainRun):
parent_run.child_chain_runs.append(child_run)
elif isinstance(child_run, ToolRun):
parent_run.child_tool_runs.append(child_run)
else:
raise TracerException(f"Invalid run type: {type(child_run)}")
@abstractmethod
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
@@ -42,15 +52,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
def _persist_session(self, session: TracerSessionCreate) -> TracerSession:
"""Persist a tracing session."""
@abstractmethod
def _generate_id(self) -> Optional[Union[int, str]]:
"""Generate an id for a run."""
def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession:
"""NOT thread safe, do not call this method from multiple threads."""
session_create = TracerSessionCreate(name=name, extra=kwargs)
session = self._persist_session(session_create)
self._session = session
self.session = session
return session
@abstractmethod
@@ -61,283 +67,248 @@ class BaseTracer(BaseCallbackHandler, ABC):
def load_default_session(self) -> TracerSession:
"""Load the default tracing session and set it as the Tracer's session."""
@property
@abstractmethod
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
"""Get the tracer stack."""
@property
@abstractmethod
def _execution_order(self) -> int:
"""Get the execution order for a run."""
@_execution_order.setter
@abstractmethod
def _execution_order(self, value: int) -> None:
"""Set the execution order for a run."""
@property
@abstractmethod
def _session(self) -> Optional[TracerSession]:
"""Get the tracing session."""
@_session.setter
@abstractmethod
def _session(self, value: TracerSession) -> None:
"""Set the tracing session."""
def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Start a trace for a run."""
self._execution_order += 1
if self._stack:
if not (
isinstance(self._stack[-1], ChainRun)
or isinstance(self._stack[-1], ToolRun)
):
if run.parent_uuid:
parent_run = self.run_map[run.parent_uuid]
if parent_run:
if isinstance(parent_run, LLMRun):
raise TracerException(
f"Nested {run.__class__.__name__} can only be"
f" logged inside a ChainRun or ToolRun"
"Cannot add child run to an LLM run. "
"LLM runs are not allowed to have children."
)
self._add_child_run(parent_run, run)
else:
raise TracerException(
f"Parent run with UUID {run.parent_uuid} not found."
)
self._add_child_run(self._stack[-1], run)
self._stack.append(run)
def _end_trace(self) -> None:
self.run_map[run.uuid] = run
def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""End a trace for a run."""
run = self._stack.pop()
if not self._stack:
self._execution_order = 1
if not run.parent_uuid:
self._persist_run(run)
else:
parent_run = self.run_map.get(run.parent_uuid)
if parent_run is None:
raise TracerException(
f"Parent run with UUID {run.parent_uuid} not found."
)
if isinstance(parent_run, LLMRun):
raise TracerException("LLM Runs are not allowed to have children. ")
if run.child_execution_order > parent_run.child_execution_order:
parent_run.child_execution_order = run.child_execution_order
self.run_map.pop(run.uuid)
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
"""Get the execution order for a run."""
if parent_run_id is None:
return 1
parent_run = self.run_map.get(parent_run_id)
if parent_run is None:
raise TracerException(f"Parent run with UUID {parent_run_id} not found.")
if isinstance(parent_run, LLMRun):
raise TracerException("LLM Runs are not allowed to have children. ")
return parent_run.child_execution_order + 1
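An illustration of the bookkeeping these rules produce (orders are illustrative):

    # chain   execution_order=1; child_execution_order grows as children finish
    # ├─ llm   execution_order=2  (parent's child_execution_order + 1)
    # └─ tool  execution_order=3  (the finished LLM run bumped the parent to 2)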
def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
self,
serialized: Dict[str, Any],
prompts: List[str],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
if self._session is None:
raise TracerException(
"Initialize a session with `new_session()` before starting a trace."
)
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
llm_run = LLMRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
serialized=serialized,
prompts=prompts,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
session_id=self._session.id,
id=self._generate_id(),
execution_order=execution_order,
child_execution_order=execution_order,
session_id=self.session.id,
)
self._start_trace(llm_run)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Handle a new token for an LLM run."""
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for an LLM run."""
if not self._stack or not isinstance(self._stack[-1], LLMRun):
if not run_id:
raise TracerException("No run_id provided for on_llm_end callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].response = response
self._end_trace()
llm_run.response = response
llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
**kwargs: Any,
) -> None:
"""Handle an error for an LLM run."""
if not self._stack or not isinstance(self._stack[-1], LLMRun):
if not run_id:
raise TracerException("No run_id provided for on_llm_error callback.")
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or not isinstance(llm_run, LLMRun):
raise TracerException("No LLMRun found to be traced")
self._stack[-1].error = repr(error)
self._stack[-1].end_time = datetime.utcnow()
self._end_trace()
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
self._end_trace(llm_run)
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a chain run."""
if self._session is None:
raise TracerException(
"Initialize a session with `new_session()` before starting a trace."
)
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
chain_run = ChainRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
serialized=serialized,
inputs=inputs,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
session_id=self._session.id,
id=self._generate_id(),
session_id=self.session.id,
)
self._start_trace(chain_run)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
def on_chain_end(
self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
) -> None:
"""End a trace for a chain run."""
if not self._stack or not isinstance(self._stack[-1], ChainRun):
run_id_ = str(run_id)
chain_run = self.run_map.get(run_id_)
if chain_run is None or not isinstance(chain_run, ChainRun):
raise TracerException("No ChainRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].outputs = outputs
self._end_trace()
chain_run.outputs = outputs
chain_run.end_time = datetime.utcnow()
self._end_trace(chain_run)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
**kwargs: Any,
) -> None:
"""Handle an error for a chain run."""
if not self._stack or not isinstance(self._stack[-1], ChainRun):
run_id_ = str(run_id)
chain_run = self.run_map.get(run_id_)
if chain_run is None or not isinstance(chain_run, ChainRun):
raise TracerException("No ChainRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].error = repr(error)
self._end_trace()
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
self._end_trace(chain_run)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
self,
serialized: Dict[str, Any],
input_str: str,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a tool run."""
if self._session is None:
raise TracerException(
"Initialize a session with `new_session()` before starting a trace."
)
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
tool_run = ToolRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
serialized=serialized,
# TODO: this duplicates the serialized info above and is not needed.
action=str(serialized),
tool_input=input_str,
extra=kwargs,
start_time=datetime.utcnow(),
execution_order=self._execution_order,
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
session_id=self._session.id,
id=self._generate_id(),
session_id=self.session.id,
)
self._start_trace(tool_run)
def on_tool_end(self, output: str, **kwargs: Any) -> None:
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for a tool run."""
if not self._stack or not isinstance(self._stack[-1], ToolRun):
run_id_ = str(run_id)
tool_run = self.run_map.get(run_id_)
if tool_run is None or not isinstance(tool_run, ToolRun):
raise TracerException("No ToolRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].output = output
self._end_trace()
tool_run.output = output
tool_run.end_time = datetime.utcnow()
self._end_trace(tool_run)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
error: Union[Exception, KeyboardInterrupt],
*,
run_id: UUID,
**kwargs: Any,
) -> None:
"""Handle an error for a tool run."""
if not self._stack or not isinstance(self._stack[-1], ToolRun):
run_id_ = str(run_id)
tool_run = self.run_map.get(run_id_)
if tool_run is None or not isinstance(tool_run, ToolRun):
raise TracerException("No ToolRun found to be traced")
self._stack[-1].end_time = datetime.utcnow()
self._stack[-1].error = repr(error)
tool_run.error = repr(error)
tool_run.end_time = datetime.utcnow()
self._end_trace(tool_run)
self._end_trace()
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
return self
def on_text(self, text: str, **kwargs: Any) -> None:
"""Handle a text message."""
pass
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
"""Handle an agent finish message."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Do nothing."""
pass
class Tracer(BaseTracer, ABC):
"""A non-thread safe implementation of the BaseTracer interface."""
def __init__(self) -> None:
"""Initialize a tracer."""
self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = []
self._tracer_execution_order = 1
self._tracer_session: Optional[TracerSession] = None
@property
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
"""Get the tracer stack."""
return self._tracer_stack
@property
def _execution_order(self) -> int:
"""Get the execution order for a run."""
return self._tracer_execution_order
@_execution_order.setter
def _execution_order(self, value: int) -> None:
"""Set the execution order for a run."""
self._tracer_execution_order = value
@property
def _session(self) -> Optional[TracerSession]:
"""Get the tracing session."""
return self._tracer_session
@_session.setter
def _session(self, value: TracerSession) -> None:
"""Set the tracing session."""
if self._stack:
raise TracerException(
"Cannot set a session while a trace is being recorded"
)
self._tracer_session = value
@dataclass
class TracerStack(threading.local):
"""A stack of runs used for logging."""
stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list)
execution_order: int = 1
class SharedTracer(Singleton, BaseTracer, ABC):
"""A thread-safe Singleton implementation of BaseTracer."""
_tracer_stack = TracerStack()
_tracer_session = None
@property
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]:
"""Get the tracer stack."""
return self._tracer_stack.stack
@property
def _execution_order(self) -> int:
"""Get the execution order for a run."""
return self._tracer_stack.execution_order
@_execution_order.setter
def _execution_order(self, value: int) -> None:
"""Set the execution order for a run."""
self._tracer_stack.execution_order = value
@property
def _session(self) -> Optional[TracerSession]:
"""Get the tracing session."""
return self._tracer_session
@_session.setter
def _session(self, value: TracerSession) -> None:
"""Set the tracing session."""
with self._lock:
# TODO: currently, we are only checking current thread's stack.
# Need to make sure that we are not in the middle of a trace
# in any thread.
if self._stack:
raise TracerException(
"Cannot set a session while a trace is being recorded"
)
self._tracer_session = value
def __copy__(self) -> BaseTracer:
"""Copy the tracer."""
return self


@@ -3,7 +3,6 @@ from __future__ import annotations
import logging
import os
from abc import ABC
from typing import Any, Dict, Optional, Union
import requests
@@ -18,14 +17,17 @@ from langchain.callbacks.tracers.schemas import (
)
class BaseLangChainTracer(BaseTracer, ABC):
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
always_verbose: bool = True
_endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
_headers: Dict[str, Any] = {"Content-Type": "application/json"}
def __init__(self, session_name: str = "default", **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
self._headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
_headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
self.session = self.load_session(session_name)
def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
@@ -59,54 +61,29 @@ class BaseLangChainTracer(BaseTracer, ABC):
session = TracerSession(id=1, **session_create.dict())
return session
def load_session(self, session_name: str) -> TracerSession:
def _load_session(self, session_name: Optional[str] = None) -> TracerSession:
"""Load a session from the tracer."""
try:
r = requests.get(
f"{self._endpoint}/sessions?name={session_name}",
headers=self._headers,
)
url = f"{self._endpoint}/sessions"
if session_name:
url += f"?name={session_name}"
r = requests.get(url, headers=self._headers)
tracer_session = TracerSession(**r.json()[0])
self._session = tracer_session
return tracer_session
except Exception as e:
session_type = "default" if not session_name else session_name
logging.warning(
f"Failed to load session {session_name}, using empty session: {e}"
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSession(id=1)
self._session = tracer_session
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> TracerSession:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> TracerSession:
"""Load the default tracing session and set it as the Tracer's session."""
try:
r = requests.get(
f"{self._endpoint}/sessions",
headers=self._headers,
)
# Use the first session result
tracer_session = TracerSession(**r.json()[0])
self._session = tracer_session
return tracer_session
except Exception as e:
logging.warning(f"Failed to default session, using empty session: {e}")
tracer_session = TracerSession(id=1)
self._session = tracer_session
return tracer_session
def _add_child_run(
self,
parent_run: Union[ChainRun, ToolRun],
child_run: Union[LLMRun, ChainRun, ToolRun],
) -> None:
"""Add child run to a chain run or tool run."""
if isinstance(child_run, LLMRun):
parent_run.child_llm_runs.append(child_run)
elif isinstance(child_run, ChainRun):
parent_run.child_chain_runs.append(child_run)
else:
parent_run.child_tool_runs.append(child_run)
def _generate_id(self) -> Optional[Union[int, str]]:
"""Generate an id for a run."""
return None
return self._load_session("default")
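A sketch of standing the tracer up against a local endpoint (URL and key are placeholders; the env names are the ones read in `__init__`):

    import os

    os.environ["LANGCHAIN_ENDPOINT"] = "http://localhost:8000"
    os.environ["LANGCHAIN_API_KEY"] = "..."  # optional; sent as the x-api-key header
    tracer = LangChainTracer(session_name="default")
    callbacks = [tracer]                     # attach to any LLM, chain, or agent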


@@ -2,7 +2,7 @@
from __future__ import annotations
import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
@@ -32,11 +32,13 @@ class TracerSession(TracerSessionBase):
class BaseRun(BaseModel):
"""Base class for Run."""
id: Optional[Union[int, str]] = None
uuid: str
parent_uuid: Optional[str] = None
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: Optional[Dict[str, Any]] = None
execution_order: int
child_execution_order: int
serialized: Dict[str, Any]
session_id: int
error: Optional[str] = None
@@ -57,7 +59,6 @@ class ChainRun(BaseRun):
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
class ToolRun(BaseRun):
@@ -69,7 +70,6 @@ class ToolRun(BaseRun):
child_llm_runs: List[LLMRun] = Field(default_factory=list)
child_chain_runs: List[ChainRun] = Field(default_factory=list)
child_tool_runs: List[ToolRun] = Field(default_factory=list)
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list)
ChainRun.update_forward_refs()


@@ -5,12 +5,16 @@ from typing import Any, Dict, List, Optional
from pydantic import Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.requests import TextRequestsWrapper
from langchain.schema import BaseLanguageModel
class APIChain(Chain):
@@ -61,16 +65,21 @@ class APIChain(Chain):
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key]
api_url = self.api_request_chain.predict(
question=question, api_docs=self.api_docs
)
self.callback_manager.on_text(
api_url, color="green", end="\n", verbose=self.verbose
question=question,
api_docs=self.api_docs,
callbacks=_run_manager.get_child(),
)
_run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose)
api_response = self.requests_wrapper.get(api_url)
self.callback_manager.on_text(
_run_manager.on_text(
api_response, color="yellow", end="\n", verbose=self.verbose
)
answer = self.api_answer_chain.predict(
@@ -78,19 +87,27 @@ class APIChain(Chain):
api_docs=self.api_docs,
api_url=api_url,
api_response=api_response,
callbacks=_run_manager.get_child(),
)
return {self.output_key: answer}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.question_key]
api_url = await self.api_request_chain.apredict(
question=question, api_docs=self.api_docs
question=question,
api_docs=self.api_docs,
callbacks=_run_manager.get_child(),
)
self.callback_manager.on_text(
await _run_manager.on_text(
api_url, color="green", end="\n", verbose=self.verbose
)
api_response = await self.requests_wrapper.aget(api_url)
self.callback_manager.on_text(
await _run_manager.on_text(
api_response, color="yellow", end="\n", verbose=self.verbose
)
answer = await self.api_answer_chain.apredict(
@ -98,6 +115,7 @@ class APIChain(Chain):
api_docs=self.api_docs,
api_url=api_url,
api_response=api_response,
callbacks=_run_manager.get_child(),
)
return {self.output_key: answer}
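Both code paths now hand `_run_manager.get_child()` to the sub-chains, so a handler supplied at call time also observes the request and answer chains. A minimal usage sketch, assuming the usual `from_llm_and_api_docs` constructor; the docs string and question are placeholders:

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import APIChain
from langchain.llms import OpenAI

api_docs = "GET /weather?city=<name> returns the current weather for <name>."
chain = APIChain.from_llm_and_api_docs(OpenAI(temperature=0), api_docs)
# The handler is scoped to this one call and is forwarded to
# api_request_chain and api_answer_chain via get_child().
chain.run("What is the weather in Paris?", callbacks=[StdOutCallbackHandler()])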

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast
from pydantic import BaseModel, Field
from requests import Response
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.api.openapi.requests_chain import APIRequesterChain
from langchain.chains.api.openapi.response_chain import APIResponderChain
from langchain.chains.base import Chain
@ -106,16 +107,21 @@ class OpenAPIEndpointChain(Chain, BaseModel):
else:
return {self.output_key: output}
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
intermediate_steps = {}
instructions = inputs[self.instructions_key]
instructions = instructions[: self.max_text_length]
_api_arguments = self.api_request_chain.predict_and_parse(
instructions=instructions
instructions=instructions, callbacks=_run_manager.get_child()
)
api_arguments = cast(str, _api_arguments)
intermediate_steps["request_args"] = api_arguments
self.callback_manager.on_text(
_run_manager.on_text(
api_arguments, color="green", end="\n", verbose=self.verbose
)
if api_arguments.startswith("ERROR"):
@ -141,18 +147,17 @@ class OpenAPIEndpointChain(Chain, BaseModel):
response_text = f"Error with message {str(e)}"
response_text = response_text[: self.max_text_length]
intermediate_steps["response_text"] = response_text
self.callback_manager.on_text(
_run_manager.on_text(
response_text, color="blue", end="\n", verbose=self.verbose
)
if self.api_response_chain is not None:
_answer = self.api_response_chain.predict_and_parse(
response=response_text,
instructions=instructions,
callbacks=_run_manager.get_child(),
)
answer = cast(str, _answer)
self.callback_manager.on_text(
answer, color="yellow", end="\n", verbose=self.verbose
)
_run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose)
return self._get_output(answer, intermediate_steps)
else:
return self._get_output(response_text, intermediate_steps)
@ -188,6 +193,7 @@ class OpenAPIEndpointChain(Chain, BaseModel):
verbose: bool = False,
return_intermediate_steps: bool = False,
raw_response: bool = False,
callbacks: Callbacks = None,
**kwargs: Any
# TODO: Handle async
) -> "OpenAPIEndpointChain":
@ -198,12 +204,17 @@ class OpenAPIEndpointChain(Chain, BaseModel):
path_params=operation.path_params,
)
requests_chain = APIRequesterChain.from_llm_and_typescript(
llm, typescript_definition=operation.to_typescript(), verbose=verbose
llm,
typescript_definition=operation.to_typescript(),
verbose=verbose,
callbacks=callbacks,
)
if raw_response:
response_chain = None
else:
response_chain = APIResponderChain.from_llm(llm, verbose=verbose)
response_chain = APIResponderChain.from_llm(
llm, verbose=verbose, callbacks=callbacks
)
_requests = requests or Requests()
return cls(
api_request_chain=requests_chain,
@ -213,5 +224,6 @@ class OpenAPIEndpointChain(Chain, BaseModel):
param_mapping=param_mapping,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
callbacks=callbacks,
**kwargs,
)

View File

@ -2,6 +2,7 @@
import json
import re
from typing import Any
from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE
from langchain.chains.llm import LLMChain
@ -36,7 +37,11 @@ class APIRequesterChain(LLMChain):
@classmethod
def from_llm_and_typescript(
cls, llm: BaseLLM, typescript_definition: str, verbose: bool = True
cls,
llm: BaseLLM,
typescript_definition: str,
verbose: bool = True,
**kwargs: Any,
) -> LLMChain:
"""Get the request parser."""
output_parser = APIRequesterOutputParser()
@ -46,4 +51,4 @@ class APIRequesterChain(LLMChain):
partial_variables={"schema": typescript_definition},
input_variables=["instructions"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)

View File

@ -2,6 +2,7 @@
import json
import re
from typing import Any
from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE
from langchain.chains.llm import LLMChain
@ -35,7 +36,7 @@ class APIResponderChain(LLMChain):
"""Get the response parser."""
@classmethod
def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain:
"""Get the response parser."""
output_parser = APIResponderOutputParser()
prompt = PromptTemplate(
@ -43,4 +44,4 @@ class APIResponderChain(LLMChain):
output_parser=output_parser,
input_variables=["response", "instructions"],
)
return cls(prompt=prompt, llm=llm, verbose=verbose)
return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs)

View File

@ -1,15 +1,23 @@
"""Base interface that all chains should implement."""
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.schema import BaseMemory
@ -21,9 +29,8 @@ class Chain(BaseModel, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
callback_manager: BaseCallbackManager = Field(
default_factory=get_callback_manager, exclude=True
)
callbacks: Callbacks = None
callback_manager: Optional[BaseCallbackManager] = None
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
@ -37,15 +44,16 @@ class Chain(BaseModel, ABC):
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
@ -82,15 +90,26 @@ class Chain(BaseModel, ABC):
)
@abstractmethod
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
raise NotImplementedError("Async call not supported for this chain type.")
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@ -104,21 +123,31 @@ class Chain(BaseModel, ABC):
"""
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = self._call(inputs)
outputs = (
self._call(inputs, run_manager=run_manager)
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
run_manager.on_chain_error(e)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
async def acall(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@ -132,30 +161,24 @@ class Chain(BaseModel, ABC):
"""
inputs = self.prep_inputs(inputs)
if self.callback_manager.is_async:
await self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
else:
self.callback_manager.on_chain_start(
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = await self._acall(inputs)
outputs = (
await self._acall(inputs, run_manager=run_manager)
if new_arg_supported
else await self._acall(inputs)
)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_chain_error(e, verbose=self.verbose)
else:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
await run_manager.on_chain_error(e)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
else:
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
await run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
def prep_outputs(
@ -195,11 +218,13 @@ class Chain(BaseModel, ABC):
self._validate_inputs(inputs)
return inputs
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(self, *args: Any, **kwargs: Any) -> str:
def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
@ -210,17 +235,17 @@ class Chain(BaseModel, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0])[self.output_keys[0]]
return self(args[0], callbacks=callbacks)[self.output_keys[0]]
if kwargs and not args:
return self(kwargs)[self.output_keys[0]]
return self(kwargs, callbacks=callbacks)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
async def arun(self, *args: Any, **kwargs: Any) -> str:
async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
@ -231,10 +256,10 @@ class Chain(BaseModel, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (await self.acall(args[0]))[self.output_keys[0]]
return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]]
if kwargs and not args:
return (await self.acall(kwargs))[self.output_keys[0]]
return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"

View File

@ -5,6 +5,10 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
@ -68,19 +72,33 @@ class BaseCombineDocumentsChain(Chain, ABC):
) -> Tuple[str, dict]:
"""Combine documents into a single string asynchronously."""
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = self.combine_docs(docs, **other_keys)
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
extra_return_dict[self.output_key] = output
return extra_return_dict
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, List[Document]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
output, extra_return_dict = await self.acombine_docs(docs, **other_keys)
output, extra_return_dict = await self.acombine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
extra_return_dict[self.output_key] = output
return extra_return_dict
@ -108,10 +126,17 @@ class AnalyzeDocumentChain(Chain):
"""
return self.combine_docs_chain.output_keys
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
document = inputs[self.input_key]
docs = self.text_splitter.create_documents([document])
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
other_keys[self.combine_docs_chain.input_key] = docs
return self.combine_docs_chain(other_keys, return_only_outputs=True)
return self.combine_docs_chain(
other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
)
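`AnalyzeDocumentChain` now forwards callbacks to its child the same way. A usage sketch, assuming the standard summarize loader; the input text is a placeholder:

from langchain.chains import AnalyzeDocumentChain
from langchain.chains.summarize import load_summarize_chain
from langchain.llms import OpenAI

summarize = load_summarize_chain(OpenAI(temperature=0), chain_type="map_reduce")
chain = AnalyzeDocumentChain(combine_docs_chain=summarize)
# The document is split, then combine_docs_chain runs as a child of this call.
chain.run("A long report to be split and summarized ...")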

View File

@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from pydantic import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
@ -129,7 +130,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return self.combine_document_chain
def combine_docs(
self, docs: List[Document], token_max: int = 3000, **kwargs: Any
self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
"""Combine documents in a map reduce manner.
@ -138,12 +143,15 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"""
results = self.llm_chain.apply(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
[{self.document_variable_name: d.page_content, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(
results, docs, token_max, callbacks=callbacks, **kwargs
)
return self._process_results(results, docs, token_max, **kwargs)
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents in a map reduce manner.
@ -152,15 +160,17 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
"""
results = await self.llm_chain.aapply(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(results, docs, **kwargs)
return self._process_results(results, docs, callbacks=callbacks, **kwargs)
def _process_results(
self,
results: List[Dict],
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
question_result_key = self.llm_chain.output_key
@ -173,7 +183,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(input_documents=docs, **kwargs)
return self._collapse_chain.run(
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
@ -191,7 +203,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
extra_return_dict = {"intermediate_steps": _results}
else:
extra_return_dict = {}
output = self.combine_document_chain.run(input_documents=result_docs, **kwargs)
output = self.combine_document_chain.run(
input_documents=result_docs, callbacks=callbacks, **kwargs
)
return output, extra_return_dict
@property
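Because the map step, any collapse passes, and the final combine all receive the same `callbacks`, one handler can count every LLM call in the tree. A sketch of such a handler, assuming (as in the refactored callbacks module) that `BaseCallbackHandler` hooks default to no-ops so only the needed one is overridden:

from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler

class LLMCallCounter(BaseCallbackHandler):
    """Counts prompts sent to the LLM anywhere in the run tree."""

    def __init__(self) -> None:
        self.prompt_count = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.prompt_count += len(prompts)

counter = LLMCallCounter()
# e.g. map_reduce_chain.run(input_documents=docs, callbacks=[counter])
# counter.prompt_count then covers the map calls, any collapse, and the combine.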

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from pydantic import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
@ -89,19 +90,22 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
)
return values
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents in a map rerank manner.
Combine by mapping first chain over all documents, then reranking the results.
"""
results = self.llm_chain.apply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(docs, results)
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents in a map rerank manner.
@ -109,7 +113,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
"""
results = await self.llm_chain.aapply_and_parse(
# FYI - this is parallelized and so it is fast.
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(docs, results)

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Tuple
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
@ -85,29 +86,31 @@ class RefineDocumentsChain(BaseCombineDocumentsChain):
)
return values
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain."""
inputs = self._construct_initial_inputs(docs, **kwargs)
res = self.initial_llm_chain.predict(**inputs)
res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
refine_steps = [res]
for doc in docs[1:]:
base_inputs = self._construct_refine_inputs(doc, res)
inputs = {**base_inputs, **kwargs}
res = self.refine_llm_chain.predict(**inputs)
res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
refine_steps.append(res)
return self._construct_result(refine_steps, res)
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Combine by mapping first chain over all, then stuffing into final chain."""
inputs = self._construct_initial_inputs(docs, **kwargs)
res = await self.initial_llm_chain.apredict(**inputs)
res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
refine_steps = [res]
for doc in docs[1:]:
base_inputs = self._construct_refine_inputs(doc, res)
inputs = {**base_inputs, **kwargs}
res = await self.refine_llm_chain.apredict(**inputs)
res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
refine_steps.append(res)
return self._construct_result(refine_steps, res)

View File

@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import (
BaseCombineDocumentsChain,
format_document,
@ -77,19 +78,21 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(**inputs), {}
async def acombine_docs(
self, docs: List[Document], **kwargs: Any
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return await self.llm_chain.apredict(**inputs), {}
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {}
@property
def _chain_type(self) -> str:
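`combine_docs` itself now takes `callbacks` directly. A self-contained sketch wiring a `StuffDocumentsChain` by hand; the prompt and variable names are illustrative:

from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

prompt = PromptTemplate(input_variables=["context"], template="Summarize:\n{context}")
chain = StuffDocumentsChain(
    llm_chain=LLMChain(llm=OpenAI(temperature=0), prompt=prompt),
    document_variable_name="context",
)
docs = [Document(page_content="Callbacks are now passed per call.")]
summary, _ = chain.combine_docs(docs)  # callbacks=... is optional here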

View File

@ -1,13 +1,14 @@
"""Chain for applying constitutional principles to the outputs of another chain."""
from typing import Any, Dict, List, Optional
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.constitutional_ai.principles import PRINCIPLES
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class ConstitutionalChain(Chain):
@ -86,11 +87,16 @@ class ConstitutionalChain(Chain):
"""Defines the output keys."""
return ["output"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
response = self.chain.run(**inputs)
input_prompt = self.chain.prompt.format(**inputs)
self.callback_manager.on_text(
_run_manager.on_text(
text="Initial response: " + response + "\n\n",
verbose=self.verbose,
color="yellow",
@ -103,6 +109,7 @@ class ConstitutionalChain(Chain):
input_prompt=input_prompt,
output_from_model=response,
critique_request=constitutional_principle.critique_request,
callbacks=_run_manager.get_child(),
)
critique = self._parse_critique(
output_string=raw_critique,
@ -116,22 +123,23 @@ class ConstitutionalChain(Chain):
critique_request=constitutional_principle.critique_request,
critique=critique,
revision_request=constitutional_principle.revision_request,
callbacks=_run_manager.get_child(),
).strip()
response = revision
self.callback_manager.on_text(
_run_manager.on_text(
text=f"Applying {constitutional_principle.name}..." + "\n\n",
verbose=self.verbose,
color="green",
)
self.callback_manager.on_text(
_run_manager.on_text(
text="Critique: " + critique + "\n\n",
verbose=self.verbose,
color="blue",
)
self.callback_manager.on_text(
_run_manager.on_text(
text="Updated response: " + revision + "\n\n",
verbose=self.verbose,
color="yellow",

View File

@ -8,6 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
@ -15,7 +20,7 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document
from langchain.schema import BaseMessage, BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
# Depending on the memory type and configuration, the chat history format may differ.
@ -81,14 +86,20 @@ class BaseConversationalRetrievalChain(Chain):
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs."""
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
if chat_history_str:
callbacks = _run_manager.get_child()
new_question = self.question_generator.run(
question=question, chat_history=chat_history_str
question=question, chat_history=chat_history_str, callbacks=callbacks
)
else:
new_question = question
@ -96,7 +107,9 @@ class BaseConversationalRetrievalChain(Chain):
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
answer = self.combine_docs_chain.run(input_documents=docs, **new_inputs)
answer = self.combine_docs_chain.run(
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
@ -106,13 +119,19 @@ class BaseConversationalRetrievalChain(Chain):
async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs."""
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs["question"]
get_chat_history = self.get_chat_history or _get_chat_history
chat_history_str = get_chat_history(inputs["chat_history"])
if chat_history_str:
callbacks = _run_manager.get_child()
new_question = await self.question_generator.arun(
question=question, chat_history=chat_history_str
question=question, chat_history=chat_history_str, callbacks=callbacks
)
else:
new_question = question
@ -120,7 +139,9 @@ class BaseConversationalRetrievalChain(Chain):
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
answer = await self.combine_docs_chain.arun(input_documents=docs, **new_inputs)
answer = await self.combine_docs_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
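A usage sketch of the refactored flow, assuming the FAISS extra is installed; texts and the question are placeholders. The question generator and the combine-docs chain both run as children of the outer call:

from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import FAISS

vectorstore = FAISS.from_texts(
    ["Callbacks are now threaded through run managers."], OpenAIEmbeddings()
)
qa = ConversationalRetrievalChain.from_llm(
    OpenAI(temperature=0), vectorstore.as_retriever()
)
result = qa({"question": "How are callbacks threaded?", "chat_history": []})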

View File

@ -1,10 +1,11 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
from langchain.chains.llm import LLMChain
@ -51,18 +52,25 @@ class GraphQAChain(Chain):
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
entity_chain = LLMChain(llm=llm, prompt=entity_prompt)
return cls(qa_chain=qa_chain, entity_extraction_chain=entity_chain, **kwargs)
return cls(
qa_chain=qa_chain,
entity_extraction_chain=entity_chain,
**kwargs,
)
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Extract entities, look up info and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
entity_string = self.entity_extraction_chain.run(question)
self.callback_manager.on_text(
"Entities Extracted:", end="\n", verbose=self.verbose
)
self.callback_manager.on_text(
_run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
_run_manager.on_text(
entity_string, color="green", end="\n", verbose=self.verbose
)
entities = get_entities(entity_string)
@ -70,9 +78,10 @@ class GraphQAChain(Chain):
for entity in entities:
triplets = self.graph.get_entity_knowledge(entity)
context += "\n".join(triplets)
self.callback_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
self.callback_manager.on_text(
context, color="green", end="\n", verbose=self.verbose
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=_run_manager.get_child(),
)
result = self.qa_chain({"question": question, "context": context})
return {self.output_key: result[self.qa_chain.output_key]}

View File

@ -4,11 +4,12 @@ https://arxiv.org/abs/2212.10496
"""
from __future__ import annotations
from typing import Dict, List
from typing import Any, Dict, List, Optional
import numpy as np
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.hyde.prompts import PROMPT_MAP
from langchain.chains.llm import LLMChain
@ -57,18 +58,27 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings):
embeddings = self.embed_documents(documents)
return self.combine_embeddings(embeddings)
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Call the internal llm chain."""
return self.llm_chain._call(inputs)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
return self.llm_chain(inputs, callbacks=_run_manager.get_child())
@classmethod
def from_llm(
cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str
cls,
llm: BaseLLM,
base_embeddings: Embeddings,
prompt_key: str,
**kwargs: Any,
) -> HypotheticalDocumentEmbedder:
"""Load and use LLMChain for a specific prompt key."""
prompt = PROMPT_MAP[prompt_key]
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain)
return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs)
@property
def _chain_type(self) -> str:
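`from_llm` now forwards `**kwargs` to the constructor, so options like `verbose` can be set at build time. A usage sketch, assuming "web_search" is one of the keys in PROMPT_MAP:

from langchain.chains import HypotheticalDocumentEmbedder
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI

embedder = HypotheticalDocumentEmbedder.from_llm(
    OpenAI(), OpenAIEmbeddings(), "web_search", verbose=True
)
vector = embedder.embed_query("What is the tallest mountain on Earth?")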

View File

@ -5,11 +5,19 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import Extra
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, LLMResult, PromptValue
from langchain.schema import LLMResult, PromptValue
class LLMChain(Chain):
@ -53,21 +61,40 @@ class LLMChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
return self.llm.generate_prompt(prompts, stop)
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
)
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
async def agenerate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = await self.aprep_prompts(input_list)
return await self.llm.agenerate_prompt(prompts, stop)
return await self.llm.agenerate_prompt(
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
)
def prep_prompts(
self, input_list: List[Dict[str, Any]]
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
@ -79,7 +106,8 @@ class LLMChain(Chain):
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
if run_manager:
run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
@ -88,7 +116,9 @@ class LLMChain(Chain):
return prompts, stop
async def aprep_prompts(
self, input_list: List[Dict[str, Any]]
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Tuple[List[PromptValue], Optional[List[str]]]:
"""Prepare prompts from inputs."""
stop = None
@ -100,12 +130,8 @@ class LLMChain(Chain):
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if self.callback_manager.is_async:
await self.callback_manager.on_text(
_text, end="\n", verbose=self.verbose
)
else:
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
if run_manager:
await run_manager.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
@ -113,15 +139,45 @@ class LLMChain(Chain):
prompts.append(prompt)
return prompts, stop
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = self.generate(input_list)
return self.create_outputs(response)
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
{"name": self.__class__.__name__},
{"input_list": input_list},
)
try:
response = self.generate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
run_manager.on_chain_end({"outputs": outputs})
return outputs
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
async def aapply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = await self.agenerate(input_list)
return self.create_outputs(response)
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chain_start(
{"name": self.__class__.__name__},
{"input_list": input_list},
)
try:
response = await self.agenerate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
await run_manager.on_chain_end({"outputs": outputs})
return outputs
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
"""Create outputs from response."""
@ -131,13 +187,19 @@ class LLMChain(Chain):
for generation in response.generations
]
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return (await self.aapply([inputs]))[0]
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
response = await self.agenerate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def predict(self, **kwargs: Any) -> str:
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
@ -148,12 +210,13 @@ class LLMChain(Chain):
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
return self(kwargs, callbacks=callbacks)[self.output_key]
async def apredict(self, **kwargs: Any) -> str:
async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
@ -164,31 +227,33 @@ class LLMChain(Chain):
completion = llm.predict(adjective="funny")
"""
return (await self.acall(kwargs))[self.output_key]
return (await self.acall(kwargs, callbacks=callbacks))[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call predict and then parse the results."""
result = self.predict(**kwargs)
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
async def apredict_and_parse(
self, **kwargs: Any
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call apredict and then parse the results."""
result = await self.apredict(**kwargs)
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
else:
return result
def apply_and_parse(
self, input_list: List[Dict[str, Any]]
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = self.apply(input_list)
result = self.apply(input_list, callbacks=callbacks)
return self._parse_result(result)
def _parse_result(
@ -202,10 +267,10 @@ class LLMChain(Chain):
return result
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]]
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
result = await self.aapply(input_list)
result = await self.aapply(input_list, callbacks=callbacks)
return self._parse_result(result)
@property
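With `apply` and `predict` taking `callbacks`, handlers can be scoped to a single call instead of being baked into the object. A minimal usage sketch:

from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    input_variables=["adjective"], template="Tell me a {adjective} joke."
)
chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)

handler = StdOutCallbackHandler()
chain.predict(adjective="funny", callbacks=[handler])
# apply() now starts its own chain run and forwards handlers to each generate.
chain.apply([{"adjective": "funny"}, {"adjective": "dry"}], callbacks=[handler])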

View File

@ -1,47 +1,24 @@
"""Chain that interprets a prompt and executes bash code to perform bash operations."""
from __future__ import annotations
import logging
import re
from typing import Any, Dict, List
import warnings
from typing import Any, Dict, List, Optional
from pydantic import Extra, Field
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
from langchain.schema import OutputParserException
from langchain.utilities.bash import BashProcess
logger = logging.getLogger(__name__)
class BashOutputParser(BaseOutputParser):
"""Parser for bash output."""
def parse(self, text: str) -> List[str]:
if "```bash" in text:
return self.get_code_blocks(text)
else:
raise OutputParserException(
f"Failed to parse bash output. Got: {text}",
)
@staticmethod
def get_code_blocks(t: str) -> List[str]:
"""Get multiple code blocks from the LLM result."""
code_blocks: List[str] = []
# Bash markdown code blocks
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
for match in pattern.finditer(t):
matched = match.group(1).strip()
if matched:
code_blocks.extend(
[line for line in matched.split("\n") if line.strip()]
)
return code_blocks
class LLMBashChain(Chain):
"""Chain that interprets a prompt and executes bash code to perform bash operations.
@ -49,15 +26,16 @@ class LLMBashChain(Chain):
.. code-block:: python
from langchain import LLMBashChain, OpenAI
llm_bash = LLMBashChain(llm=OpenAI())
llm_bash = LLMBashChain.from_llm(OpenAI())
"""
llm: BaseLanguageModel
"""LLM wrapper to use."""
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
prompt: BasePromptTemplate = PROMPT
output_parser: BaseOutputParser = Field(default_factory=BashOutputParser)
"""[Deprecated]"""
bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private:
class Config:
@ -66,6 +44,26 @@ class LLMBashChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMBashChain with an llm is deprecated. "
"Please instantiate with llm_chain or using the from_llm class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@root_validator
def validate_prompt(cls, values: Dict) -> Dict:
if values["llm_chain"].prompt.output_parser is None:
raise ValueError(
"The prompt used by llm_chain is expected to have an output_parser."
)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
@ -82,30 +80,34 @@ class LLMBashChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=self.prompt, llm=self.llm)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key], verbose=self.verbose)
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
t = llm_executor.predict(question=inputs[self.input_key])
self.callback_manager.on_text(t, color="green", verbose=self.verbose)
t = self.llm_chain.predict(
question=inputs[self.input_key], callbacks=_run_manager.get_child()
)
_run_manager.on_text(t, color="green", verbose=self.verbose)
t = t.strip()
try:
command_list = self.output_parser.parse(t)
parser = self.llm_chain.prompt.output_parser
command_list = parser.parse(t) # type: ignore[union-attr]
except OutputParserException as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
_run_manager.on_chain_error(e, verbose=self.verbose)
raise e
if self.verbose:
self.callback_manager.on_text("\nCode: ", verbose=self.verbose)
self.callback_manager.on_text(
_run_manager.on_text("\nCode: ", verbose=self.verbose)
_run_manager.on_text(
str(command_list), color="yellow", verbose=self.verbose
)
output = self.bash_process.run(command_list)
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
return {self.output_key: output}
@property
@ -113,11 +115,11 @@ class LLMBashChain(Chain):
return "llm_bash_chain"
@classmethod
def from_bash_process(
def from_llm(
cls,
bash_process: BashProcess,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> "LLMBashChain":
"""Create a LLMBashChain from a BashProcess."""
return cls(llm=llm, bash_process=bash_process, **kwargs)
) -> LLMBashChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
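Usage after the refactor, following the updated docstring's import style; the question is illustrative:

from langchain import LLMBashChain, OpenAI

llm = OpenAI(temperature=0)
bash_chain = LLMBashChain.from_llm(llm, verbose=True)  # new constructor
# LLMBashChain(llm=llm) still works, but the pre=True root validator above
# warns and builds llm_chain from the deprecated llm/prompt fields.
bash_chain.run("List the files in the current directory.")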

View File

@ -1,5 +1,11 @@
# flake8: noqa
from __future__ import annotations
import re
from typing import List
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:
@ -19,4 +25,36 @@ That is the format. Begin!
Question: {question}"""
PROMPT = PromptTemplate(input_variables=["question"], template=_PROMPT_TEMPLATE)
class BashOutputParser(BaseOutputParser):
"""Parser for bash output."""
def parse(self, text: str) -> List[str]:
if "```bash" in text:
return self.get_code_blocks(text)
else:
raise OutputParserException(
f"Failed to parse bash output. Got: {text}",
)
@staticmethod
def get_code_blocks(t: str) -> List[str]:
"""Get multiple code blocks from the LLM result."""
code_blocks: List[str] = []
# Bash markdown code blocks
pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
for match in pattern.finditer(t):
matched = match.group(1).strip()
if matched:
code_blocks.extend(
[line for line in matched.split("\n") if line.strip()]
)
return code_blocks
PROMPT = PromptTemplate(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
output_parser=BashOutputParser(),
)
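The parser now lives next to the prompt so the template can carry it as its output_parser. Its behaviour on a fenced bash block:

from langchain.chains.llm_bash.prompt import BashOutputParser

text = """I need to list the files.

```bash
cd /tmp
ls -la
```"""
BashOutputParser().parse(text)  # -> ['cd /tmp', 'ls -la']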

View File

@ -1,10 +1,12 @@
"""Chain for question-answering with self-verification."""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from typing import Dict, List
from pydantic import Extra
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_checker.prompt import (
@ -18,6 +20,48 @@ from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
def _load_question_to_checked_assertions_chain(
llm: BaseLLM,
create_draft_answer_prompt: PromptTemplate,
list_assertions_prompt: PromptTemplate,
check_assertions_prompt: PromptTemplate,
revised_answer_prompt: PromptTemplate,
) -> SequentialChain:
create_draft_answer_chain = LLMChain(
llm=llm,
prompt=create_draft_answer_prompt,
output_key="statement",
)
list_assertions_chain = LLMChain(
llm=llm,
prompt=list_assertions_prompt,
output_key="assertions",
)
check_assertions_chain = LLMChain(
llm=llm,
prompt=check_assertions_prompt,
output_key="checked_assertions",
)
revised_answer_chain = LLMChain(
llm=llm,
prompt=revised_answer_prompt,
output_key="revised_statement",
)
chains = [
create_draft_answer_chain,
list_assertions_chain,
check_assertions_chain,
revised_answer_chain,
]
question_to_checked_assertions_chain = SequentialChain(
chains=chains,
input_variables=["question"],
output_variables=["revised_statement"],
verbose=True,
)
return question_to_checked_assertions_chain
class LLMCheckerChain(Chain):
"""Chain for question-answering with self-verification.
@ -26,16 +70,21 @@ class LLMCheckerChain(Chain):
from langchain import OpenAI, LLMCheckerChain
llm = OpenAI(temperature=0.7)
checker_chain = LLMCheckerChain(llm=llm)
checker_chain = LLMCheckerChain.from_llm(llm)
"""
llm: BaseLLM
"""LLM wrapper to use."""
question_to_checked_assertions_chain: SequentialChain
llm: Optional[BaseLLM] = None
"""[Deprecated] LLM wrapper to use."""
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
"""[Deprecated]"""
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
"""[Deprecated]"""
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
"""[Deprecated]"""
revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT
"""Prompt to use when questioning the documents."""
"""[Deprecated] Prompt to use when questioning the documents."""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@ -45,6 +94,34 @@ class LLMCheckerChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMCheckerChain with an llm is deprecated. "
"Please instantiate with question_to_checked_assertions_chain "
"or using the from_llm class method."
)
if (
"question_to_checked_assertions_chain" not in values
and values["llm"] is not None
):
question_to_checked_assertions_chain = (
_load_question_to_checked_assertions_chain(
values["llm"],
values.get(
"create_draft_answer_prompt", CREATE_DRAFT_ANSWER_PROMPT
),
values.get("list_assertions_prompt", LIST_ASSERTIONS_PROMPT),
values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT),
values.get("revised_answer_prompt", REVISED_ANSWER_PROMPT),
)
)
values[
"question_to_checked_assertions_chain"
] = question_to_checked_assertions_chain
return values
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
@ -61,43 +138,43 @@ class LLMCheckerChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
create_draft_answer_chain = LLMChain(
llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement"
output = self.question_to_checked_assertions_chain(
{"question": question}, callbacks=_run_manager.get_child()
)
list_assertions_chain = LLMChain(
llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions"
)
check_assertions_chain = LLMChain(
llm=self.llm,
prompt=self.check_assertions_prompt,
output_key="checked_assertions",
)
revised_answer_chain = LLMChain(
llm=self.llm,
prompt=self.revised_answer_prompt,
output_key="revised_statement",
)
chains = [
create_draft_answer_chain,
list_assertions_chain,
check_assertions_chain,
revised_answer_chain,
]
question_to_checked_assertions_chain = SequentialChain(
chains=chains,
input_variables=["question"],
output_variables=["revised_statement"],
verbose=True,
)
output = question_to_checked_assertions_chain({"question": question})
return {self.output_key: output["revised_statement"]}
@property
def _chain_type(self) -> str:
return "llm_checker_chain"
@classmethod
def from_llm(
cls,
llm: BaseLLM,
create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT,
list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT,
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT,
**kwargs: Any,
) -> LLMCheckerChain:
question_to_checked_assertions_chain = (
_load_question_to_checked_assertions_chain(
llm,
create_draft_answer_prompt,
list_assertions_prompt,
check_assertions_prompt,
revised_answer_prompt,
)
)
return cls(
question_to_checked_assertions_chain=question_to_checked_assertions_chain,
**kwargs,
)
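Usage after the refactor mirrors the updated docstring; the question is illustrative:

from langchain import LLMCheckerChain, OpenAI

llm = OpenAI(temperature=0.7)
checker = LLMCheckerChain.from_llm(llm, verbose=True)
checker.run("What type of mammal lays the biggest eggs?")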

View File

@ -1,16 +1,23 @@
"""Chain that interprets a prompt and executes python code to do math."""
from __future__ import annotations
import math
import re
from typing import Dict, List
import warnings
from typing import Any, Dict, List, Optional
import numexpr
from pydantic import Extra
from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LLMMathChain(Chain):
@ -20,13 +27,14 @@ class LLMMathChain(Chain):
.. code-block:: python
from langchain import LLMMathChain, OpenAI
llm_math = LLMMathChain(llm=OpenAI())
llm_math = LLMMathChain.from_llm(OpenAI())
"""
llm: BaseLanguageModel
"""LLM wrapper to use."""
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
prompt: BasePromptTemplate = PROMPT
"""Prompt to use to translate to python if neccessary."""
"""[Deprecated] Prompt to use to translate to python if necessary."""
input_key: str = "question" #: :meta private:
output_key: str = "answer" #: :meta private:
@ -36,6 +44,19 @@ class LLMMathChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMMathChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
prompt = values.get("prompt", PROMPT)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Expect input key.
@ -68,15 +89,17 @@ class LLMMathChain(Chain):
# Remove any leading and trailing brackets from the output
return re.sub(r"^\[|\]$", "", output)
def _process_llm_result(self, llm_output: str) -> Dict[str, str]:
self.callback_manager.on_text(llm_output, color="green", verbose=self.verbose)
def _process_llm_result(
self, llm_output: str, run_manager: CallbackManagerForChainRun
) -> Dict[str, str]:
run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
self.callback_manager.on_text(output, color="yellow", verbose=self.verbose)
run_manager.on_text("\nAnswer: ", verbose=self.verbose)
run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
@ -86,30 +109,19 @@ class LLMMathChain(Chain):
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]:
if self.callback_manager.is_async:
await self.callback_manager.on_text(
llm_output, color="green", verbose=self.verbose
)
else:
self.callback_manager.on_text(
llm_output, color="green", verbose=self.verbose
)
async def _aprocess_llm_result(
self,
llm_output: str,
run_manager: AsyncCallbackManagerForChainRun,
) -> Dict[str, str]:
await run_manager.on_text(llm_output, color="green", verbose=self.verbose)
llm_output = llm_output.strip()
text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL)
if text_match:
expression = text_match.group(1)
output = self._evaluate_expression(expression)
if self.callback_manager.is_async:
await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
await self.callback_manager.on_text(
output, color="yellow", verbose=self.verbose
)
else:
self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose)
self.callback_manager.on_text(
output, color="yellow", verbose=self.verbose
)
await run_manager.on_text("\nAnswer: ", verbose=self.verbose)
await run_manager.on_text(output, color="yellow", verbose=self.verbose)
answer = "Answer: " + output
elif llm_output.startswith("Answer:"):
answer = llm_output
@ -119,31 +131,44 @@ class LLMMathChain(Chain):
raise ValueError(f"unknown format from LLM: {llm_output}")
return {self.output_key: answer}
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(
prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_run_manager.on_text(inputs[self.input_key])
llm_output = self.llm_chain.predict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
llm_output = llm_executor.predict(
question=inputs[self.input_key], stop=["```output"]
)
return self._process_llm_result(llm_output)
return self._process_llm_result(llm_output, _run_manager)
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(
prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager
async def _acall(
self,
inputs: Dict[str, str],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
await _run_manager.on_text(inputs[self.input_key])
llm_output = await self.llm_chain.apredict(
question=inputs[self.input_key],
stop=["```output"],
callbacks=_run_manager.get_child(),
)
if self.callback_manager.is_async:
await self.callback_manager.on_text(
inputs[self.input_key], verbose=self.verbose
)
else:
self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose)
llm_output = await llm_executor.apredict(
question=inputs[self.input_key], stop=["```output"]
)
return await self._aprocess_llm_result(llm_output)
return await self._aprocess_llm_result(llm_output, _run_manager)
@property
def _chain_type(self) -> str:
return "llm_math_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = PROMPT,
**kwargs: Any,
) -> LLMMathChain:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, **kwargs)
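A usage sketch for the migrated constructor, mirroring the updated class docstring (OpenAI key assumed; the question is illustrative):

```python
from langchain import LLMMathChain, OpenAI

llm_math = LLMMathChain.from_llm(OpenAI(temperature=0), verbose=True)
llm_math.run("What is 13 raised to the .3432 power?")
```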

View File

@ -1,10 +1,11 @@
"""Chain that hits a URL and then uses an LLM to parse results."""
from __future__ import annotations
from typing import Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain.requests import TextRequestsWrapper
@ -61,9 +62,14 @@ class LLMRequestsChain(Chain):
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
from bs4 import BeautifulSoup
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
url = inputs[self.input_key]
@ -71,7 +77,9 @@ class LLMRequestsChain(Chain):
# extract the text from the html
soup = BeautifulSoup(res, "html.parser")
other_keys[self.requests_key] = soup.get_text()[: self.text_length]
result = self.llm_chain.predict(**other_keys)
result = self.llm_chain.predict(
callbacks=_run_manager.get_child(), **other_keys
)
return {self.output_key: result}
@property
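A hedged sketch of driving `LLMRequestsChain` after this change; callbacks attached at call time now reach the inner `llm_chain` through `run_manager.get_child()`. Assumes `beautifulsoup4` is installed; the URL is illustrative:

```python
from langchain.chains import LLMChain, LLMRequestsChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

template = "Summarize this page in one sentence:\n\n{requests_result}"
prompt = PromptTemplate(input_variables=["requests_result"], template=template)
chain = LLMRequestsChain(llm_chain=LLMChain(llm=OpenAI(temperature=0), prompt=prompt))
chain({"url": "https://example.com"})  # default input key is "url"
```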

View File

@ -1,10 +1,14 @@
"""Chain for summarization with self-verification."""
from __future__ import annotations
import warnings
from pathlib import Path
from typing import Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Extra
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain
@ -27,6 +31,48 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file(
)
def _load_sequential_chain(
llm: BaseLLM,
create_assertions_prompt: PromptTemplate,
check_assertions_prompt: PromptTemplate,
revised_summary_prompt: PromptTemplate,
are_all_true_prompt: PromptTemplate,
verbose: bool = False,
) -> SequentialChain:
chain = SequentialChain(
chains=[
LLMChain(
llm=llm,
prompt=create_assertions_prompt,
output_key="assertions",
verbose=verbose,
),
LLMChain(
llm=llm,
prompt=check_assertions_prompt,
output_key="checked_assertions",
verbose=verbose,
),
LLMChain(
llm=llm,
prompt=revised_summary_prompt,
output_key="revised_summary",
verbose=verbose,
),
LLMChain(
llm=llm,
output_key="all_true",
prompt=are_all_true_prompt,
verbose=verbose,
),
],
input_variables=["summary"],
output_variables=["all_true", "revised_summary"],
verbose=verbose,
)
return chain
class LLMSummarizationCheckerChain(Chain):
"""Chain for question-answering with self-verification.
@ -35,16 +81,21 @@ class LLMSummarizationCheckerChain(Chain):
from langchain import OpenAI, LLMSummarizationCheckerChain
llm = OpenAI(temperature=0.0)
checker_chain = LLMSummarizationCheckerChain(llm=llm)
checker_chain = LLMSummarizationCheckerChain.from_llm(llm)
"""
llm: BaseLLM
"""LLM wrapper to use."""
sequential_chain: SequentialChain
llm: Optional[BaseLLM] = None
"""[Deprecated] LLM wrapper to use."""
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT
"""[Deprecated]"""
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
"""[Deprecated]"""
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT
"""[Deprecated]"""
are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT
"""[Deprecated]"""
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@ -57,6 +108,25 @@ class LLMSummarizationCheckerChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an LLMSummarizationCheckerChain with an llm is "
"deprecated. Please instantiate with"
" sequential_chain argument or using the from_llm class method."
)
if "sequential_chain" not in values and values["llm"] is not None:
values["sequential_chain"] = _load_sequential_chain(
values["llm"],
values.get("create_assertions_prompt", CREATE_ASSERTIONS_PROMPT),
values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT),
values.get("revised_summary_prompt", REVISED_SUMMARY_PROMPT),
values.get("are_all_true_prompt", ARE_ALL_TRUE_PROMPT),
verbose=values.get("verbose", False),
)
return values
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
@ -73,46 +143,21 @@ class LLMSummarizationCheckerChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
all_true = False
count = 0
output = None
original_input = inputs[self.input_key]
chain_input = original_input
while not all_true and count < self.max_checks:
chain = SequentialChain(
chains=[
LLMChain(
llm=self.llm,
prompt=self.create_assertions_prompt,
output_key="assertions",
verbose=self.verbose,
),
LLMChain(
llm=self.llm,
prompt=self.check_assertions_prompt,
output_key="checked_assertions",
verbose=self.verbose,
),
LLMChain(
llm=self.llm,
prompt=self.revised_summary_prompt,
output_key="revised_summary",
verbose=self.verbose,
),
LLMChain(
llm=self.llm,
output_key="all_true",
prompt=self.are_all_true_prompt,
verbose=self.verbose,
),
],
input_variables=["summary"],
output_variables=["all_true", "revised_summary"],
verbose=self.verbose,
output = self.sequential_chain(
{"summary": chain_input}, callbacks=_run_manager.get_child()
)
output = chain({"summary": chain_input})
count += 1
if output["all_true"].strip() == "True":
@ -131,3 +176,24 @@ class LLMSummarizationCheckerChain(Chain):
@property
def _chain_type(self) -> str:
return "llm_summarization_checker_chain"
@classmethod
def from_llm(
cls,
llm: BaseLLM,
create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT,
check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT,
revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT,
are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT,
verbose: bool = False,
**kwargs: Any,
) -> LLMSummarizationCheckerChain:
chain = _load_sequential_chain(
llm,
create_assertions_prompt,
check_assertions_prompt,
revised_summary_prompt,
are_all_true_prompt,
verbose=verbose,
)
return cls(sequential_chain=chain, verbose=verbose, **kwargs)
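As with the other migrated chains, a minimal sketch of the new entry point (the statement being checked is illustrative):

```python
from langchain.chains import LLMSummarizationCheckerChain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0.0)
checker = LLMSummarizationCheckerChain.from_llm(llm, max_checks=2, verbose=True)
checker.run("Mammals can lay eggs, birds can lay eggs, therefore birds are mammals.")
```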

View File

@ -5,10 +5,11 @@ then combines the results with another one.
"""
from __future__ import annotations
from typing import Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@ -32,16 +33,26 @@ class MapReduceChain(Chain):
@classmethod
def from_params(
cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter
cls,
llm: BaseLLM,
prompt: BasePromptTemplate,
text_splitter: TextSplitter,
callbacks: Callbacks = None,
**kwargs: Any,
) -> MapReduceChain:
"""Construct a map-reduce chain that uses the chain for map and reduce."""
llm_chain = LLMChain(llm=llm, prompt=prompt)
reduce_chain = StuffDocumentsChain(llm_chain=llm_chain)
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks)
reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks)
combine_documents_chain = MapReduceDocumentsChain(
llm_chain=llm_chain, combine_document_chain=reduce_chain
llm_chain=llm_chain,
combine_document_chain=reduce_chain,
callbacks=callbacks,
)
return cls(
combine_documents_chain=combine_documents_chain, text_splitter=text_splitter
combine_documents_chain=combine_documents_chain,
text_splitter=text_splitter,
callbacks=callbacks,
**kwargs,
)
class Config:
@ -66,9 +77,16 @@ class MapReduceChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
# Split the larger text into smaller chunks.
texts = self.text_splitter.split_text(inputs[self.input_key])
docs = [Document(page_content=text) for text in texts]
outputs = self.combine_documents_chain.run(input_documents=docs)
outputs = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child()
)
return {self.output_key: outputs}
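`from_params` now threads a single `callbacks` argument into every sub-chain it builds. A sketch, assuming the prompt has exactly one input variable so the documents chains can infer `document_variable_name`:

```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains.mapreduce import MapReduceChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter

prompt = PromptTemplate(input_variables=["text"], template="Summarize:\n\n{text}")
chain = MapReduceChain.from_params(
    llm=OpenAI(temperature=0),
    prompt=prompt,
    text_splitter=CharacterTextSplitter(),
    callbacks=[StdOutCallbackHandler()],
)
# chain.run(long_text) then summarizes arbitrarily long input (input key "input_text").
```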

View File

@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.utils import get_from_dict_or_env
@ -84,7 +85,11 @@ class OpenAIModerationChain(Chain):
return error_str
return text
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
text = inputs[self.input_key]
results = self.client.create(text)
output = self._moderate(text, results["results"][0])

View File

@ -1,10 +1,12 @@
"""Implement an LLM driven browser."""
from __future__ import annotations
from typing import Dict, List
import warnings
from typing import Any, Dict, List, Optional
from pydantic import Extra
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.natbot.prompt import PROMPT
@ -18,14 +20,15 @@ class NatBotChain(Chain):
Example:
.. code-block:: python
from langchain import NatBotChain, OpenAI
natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.")
from langchain import NatBotChain
natbot = NatBotChain.from_default("Buy me a new hat.")
"""
llm: BaseLLM
"""LLM wrapper to use."""
llm_chain: LLMChain
objective: str
"""Objective that NatBot is tasked with completing."""
llm: Optional[BaseLLM] = None
"""[Deprecated] LLM wrapper to use."""
input_url_key: str = "url" #: :meta private:
input_browser_content_key: str = "browser_content" #: :meta private:
previous_command: str = "" #: :meta private:
@ -37,11 +40,29 @@ class NatBotChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an NatBotChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT)
return values
@classmethod
def from_default(cls, objective: str) -> NatBotChain:
"""Load with default LLM."""
def from_default(cls, objective: str, **kwargs: Any) -> NatBotChain:
"""Load with default LLMChain."""
llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
return cls(llm=llm, objective=objective)
return cls.from_llm(llm, objective, **kwargs)
@classmethod
def from_llm(cls, llm: BaseLLM, objective: str, **kwargs: Any) -> NatBotChain:
"""Load from LLM."""
llm_chain = LLMChain(llm=llm, prompt=PROMPT)
return cls(llm_chain=llm_chain, objective=objective, **kwargs)
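A sketch of the two new constructors (the objective string is illustrative):

```python
from langchain.chains import NatBotChain
from langchain.llms import OpenAI

# Default LLM, as wired up in from_default above:
natbot = NatBotChain.from_default("Buy me a new hat.")
# Or bring your own LLM:
llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
natbot = NatBotChain.from_llm(llm, "Buy me a new hat.")
```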
@property
def input_keys(self) -> List[str]:
@ -59,15 +80,20 @@ class NatBotChain(Chain):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
url = inputs[self.input_url_key]
browser_content = inputs[self.input_browser_content_key]
llm_cmd = llm_executor.predict(
llm_cmd = self.llm_chain.predict(
objective=self.objective,
url=url[:100],
previous_command=self.previous_command,
browser_content=browser_content[:4500],
callbacks=_run_manager.get_child(),
)
llm_cmd = llm_cmd.strip()
self.previous_command = llm_cmd

View File

@ -4,24 +4,29 @@ As in https://arxiv.org/pdf/2211.10435.pdf.
"""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from pydantic import Extra
from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.utilities import PythonREPL
class PALChain(Chain):
"""Implements Program-Aided Language Models."""
llm: BaseLanguageModel
prompt: BasePromptTemplate
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated]"""
prompt: BasePromptTemplate = MATH_PROMPT
"""[Deprecated]"""
stop: str = "\n\n"
get_answer_expr: str = "print(solution())"
python_globals: Optional[Dict[str, Any]] = None
@ -35,6 +40,19 @@ class PALChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an PALChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the one of "
"the class method constructors from_math_prompt, "
"from_colored_object_prompt."
)
if "llm_chain" not in values and values["llm"] is not None:
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=MATH_PROMPT)
return values
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
@ -54,12 +72,16 @@ class PALChain(Chain):
else:
return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
code = llm_chain.predict(stop=[self.stop], **inputs)
self.callback_manager.on_text(
code, color="green", end="\n", verbose=self.verbose
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
code = self.llm_chain.predict(
stop=[self.stop], callbacks=_run_manager.get_child(), **inputs
)
_run_manager.on_text(code, color="green", end="\n", verbose=self.verbose)
repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals)
res = repl.run(code + f"\n{self.get_answer_expr}")
output = {self.output_key: res.strip()}
@ -70,9 +92,9 @@ class PALChain(Chain):
@classmethod
def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain:
"""Load PAL from math prompt."""
llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT)
return cls(
llm=llm,
prompt=MATH_PROMPT,
llm_chain=llm_chain,
stop="\n\n",
get_answer_expr="print(solution())",
**kwargs,
@ -83,9 +105,9 @@ class PALChain(Chain):
cls, llm: BaseLanguageModel, **kwargs: Any
) -> PALChain:
"""Load PAL from colored object prompt."""
llm_chain = LLMChain(llm=llm, prompt=COLORED_OBJECT_PROMPT)
return cls(
llm=llm,
prompt=COLORED_OBJECT_PROMPT,
llm_chain=llm_chain,
stop="\n\n\n",
get_answer_expr="print(answer)",
**kwargs,
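The PAL constructors now pre-build the `LLMChain`; usage is unchanged apart from the entry point. A sketch (the word problem is illustrative):

```python
from langchain.chains import PALChain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0, max_tokens=512)
pal_chain = PALChain.from_math_prompt(llm, verbose=True)
pal_chain.run(
    "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?"
)
```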

View File

@ -3,10 +3,10 @@ from typing import Callable, List, Tuple
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC):

View File

@ -5,11 +5,12 @@ from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@ -45,11 +46,14 @@ class QAGenerationChain(Chain):
def output_keys(self) -> List[str]:
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, List]:
docs = self.text_splitter.create_documents([inputs[self.input_key]])
results = self.llm_chain.generate([{"text": d.page_content} for d in docs])
results = self.llm_chain.generate(
[{"text": d.page_content} for d in docs], run_manager=run_manager
)
qa = [json.loads(res[0].text) for res in results.generations]
return {self.output_key: qa}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
raise NotImplementedError
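A minimal sketch of driving this chain; it assumes the module's `from_llm` constructor (not shown in this hunk), and the passage is illustrative:

```python
from langchain.chains import QAGenerationChain
from langchain.chat_models import ChatOpenAI

chain = QAGenerationChain.from_llm(ChatOpenAI(temperature=0))
# Returns a list of {"question": ..., "answer": ...} dicts parsed from the LLM output.
qa_pairs = chain.run("LangChain is a framework for developing applications powered by LLMs.")
```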

View File

@ -8,6 +8,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@ -21,7 +26,6 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
)
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BaseQAWithSourcesChain(Chain, ABC):
@ -114,9 +118,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs to run questioning over."""
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
docs = self._get_docs(inputs)
answer = self.combine_documents_chain.run(input_documents=docs, **inputs)
answer = self.combine_documents_chain.run(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s", answer)
else:
@ -133,9 +144,16 @@ class BaseQAWithSourcesChain(Chain, ABC):
async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]:
"""Get docs to run questioning over."""
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
docs = await self._aget_docs(inputs)
answer = await self.combine_documents_chain.arun(input_documents=docs, **inputs)
answer = await self.combine_documents_chain.arun(
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
)
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s", answer)
else:

View File

@ -1,6 +1,7 @@
"""Load question answering with sources chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@ -14,7 +15,6 @@ from langchain.chains.qa_with_sources import (
)
from langchain.chains.question_answering import map_rerank_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):

View File

@ -5,6 +5,7 @@ import json
from typing import Any, Callable, List, Optional, Sequence
from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.chains.query_constructor.ir import (
Comparator,
Operator,
@ -20,7 +21,7 @@ from langchain.chains.query_constructor.prompt import (
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.structured import parse_json_markdown
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
from langchain.schema import BaseOutputParser, OutputParserException
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):

View File

@ -1,6 +1,7 @@
"""Load question answering chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@ -15,7 +16,6 @@ from langchain.chains.question_answering import (
stuff_prompt,
)
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):

View File

@ -7,6 +7,11 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
@ -14,7 +19,7 @@ from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
@ -92,7 +97,11 @@ class BaseRetrievalQA(Chain):
def _get_docs(self, question: str) -> List[Document]:
"""Get documents to do question answering over."""
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
@ -104,11 +113,12 @@ class BaseRetrievalQA(Chain):
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = self._get_docs(question)
answer = self.combine_documents_chain.run(
input_documents=docs, question=question
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
if self.return_source_documents:
@ -120,7 +130,11 @@ class BaseRetrievalQA(Chain):
async def _aget_docs(self, question: str) -> List[Document]:
"""Get documents to do question answering over."""
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
@ -132,11 +146,12 @@ class BaseRetrievalQA(Chain):
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
docs = await self._aget_docs(question)
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
if self.return_source_documents:
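Since `_call`/`_acall` now forward a child manager into `combine_documents_chain`, handlers passed at call time trace the whole retrieval-QA run. A sketch (the retriever comes from an assumed, pre-built `vectorstore`):

```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI

qa = RetrievalQA.from_chain_type(
    llm=OpenAI(temperature=0),
    retriever=vectorstore.as_retriever(),  # `vectorstore` assumed to exist
)
qa({"query": "What did the author work on?"}, callbacks=[StdOutCallbackHandler()])
```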

View File

@ -1,8 +1,12 @@
"""Chain pipeline where the outputs of one step feed directly into next."""
from typing import Dict, List
from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.input import get_color_mapping
@ -86,17 +90,31 @@ class SequentialChain(Chain):
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
known_values = inputs.copy()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
for i, chain in enumerate(self.chains):
outputs = chain(known_values, return_only_outputs=True)
callbacks = _run_manager.get_child()
outputs = chain(known_values, return_only_outputs=True, callbacks=callbacks)
known_values.update(outputs)
return {k: known_values[k] for k in self.output_variables}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
known_values = inputs.copy()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
for i, chain in enumerate(self.chains):
outputs = await chain.acall(known_values, return_only_outputs=True)
outputs = await chain.acall(
known_values, return_only_outputs=True, callbacks=callbacks
)
known_values.update(outputs)
return {k: known_values[k] for k in self.output_variables}
@ -147,31 +165,37 @@ class SimpleSequentialChain(Chain):
)
return values
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = chain.run(_input)
_input = chain.run(_input, callbacks=_run_manager.get_child())
if self.strip_outputs:
_input = _input.strip()
self.callback_manager.on_text(
_run_manager.on_text(
_input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
)
return {self.output_key: _input}
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = await chain.arun(_input)
_input = await chain.arun(_input, callbacks=callbacks)
if self.strip_outputs:
_input = _input.strip()
if self.callback_manager.is_async:
await self.callback_manager.on_text(
_input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
)
else:
self.callback_manager.on_text(
await _run_manager.on_text(
_input, color=color_mapping[str(i)], end="\n", verbose=self.verbose
)
return {self.output_key: _input}
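With `run_manager.get_child()` in place, a single `callbacks` list passed to the outer chain is propagated to every step. A minimal sketch:

```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI(temperature=0)
title = LLMChain(llm=llm, prompt=PromptTemplate.from_template(
    "Write a play title about {topic}."))
synopsis = LLMChain(llm=llm, prompt=PromptTemplate.from_template(
    "Write a one-line synopsis for the play: {title}."))
overall = SimpleSequentialChain(chains=[title, synopsis])
# Handlers given here are re-attached to each child run:
overall.run("robot pirates", callbacks=[StdOutCallbackHandler()])
```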

View File

@ -1,15 +1,17 @@
"""Chain for interacting with SQL Database."""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional
from pydantic import Extra, Field
from pydantic import Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.sql_database import SQLDatabase
@ -21,15 +23,16 @@ class SQLDatabaseChain(Chain):
from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
db = SQLDatabase(...)
db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
db_chain = SQLDatabaseChain.from_llm(OpenAI(), db)
"""
llm: BaseLanguageModel
"""LLM wrapper to use."""
llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None
"""[Deprecated] LLM wrapper to use."""
database: SQLDatabase = Field(exclude=True)
"""SQL Database to connect to."""
prompt: Optional[BasePromptTemplate] = None
"""Prompt to use to translate natural language to SQL."""
"""[Deprecated] Prompt to use to translate natural language to SQL."""
top_k: int = 5
"""Number of results to return from the query"""
input_key: str = "query" #: :meta private:
@ -45,6 +48,22 @@ class SQLDatabaseChain(Chain):
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
if "llm" in values:
warnings.warn(
"Directly instantiating an SQLDatabaseChain with an llm is deprecated. "
"Please instantiate with llm_chain argument or using the from_llm "
"class method."
)
if "llm_chain" not in values and values["llm"] is not None:
database = values["database"]
prompt = values.get("prompt") or SQL_PROMPTS.get(
database.dialect, PROMPT
)
values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
return values
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
@ -64,11 +83,14 @@ class SQLDatabaseChain(Chain):
else:
return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
prompt = self.prompt or SQL_PROMPTS.get(self.database.dialect, PROMPT)
llm_chain = LLMChain(llm=self.llm, prompt=prompt)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
input_text = f"{inputs[self.input_key]}\nSQLQuery:"
self.callback_manager.on_text(input_text, verbose=self.verbose)
_run_manager.on_text(input_text, verbose=self.verbose)
# If not present, then defaults to None which is all tables.
table_names_to_use = inputs.get("table_names_to_use")
table_info = self.database.get_table_info(table_names=table_names_to_use)
@ -80,24 +102,26 @@ class SQLDatabaseChain(Chain):
"stop": ["\nSQLResult:"],
}
intermediate_steps = []
sql_cmd = llm_chain.predict(**llm_inputs)
sql_cmd = self.llm_chain.predict(
callbacks=_run_manager.get_child(), **llm_inputs
)
intermediate_steps.append(sql_cmd)
self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
_run_manager.on_text(sql_cmd, color="green", verbose=self.verbose)
result = self.database.run(sql_cmd)
intermediate_steps.append(result)
self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose)
self.callback_manager.on_text(result, color="yellow", verbose=self.verbose)
_run_manager.on_text("\nSQLResult: ", verbose=self.verbose)
_run_manager.on_text(result, color="yellow", verbose=self.verbose)
# If return direct, we just set the final result equal to the sql query
if self.return_direct:
final_result = result
else:
self.callback_manager.on_text("\nAnswer:", verbose=self.verbose)
_run_manager.on_text("\nAnswer:", verbose=self.verbose)
input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
llm_inputs["input"] = input_text
final_result = llm_chain.predict(**llm_inputs)
self.callback_manager.on_text(
final_result, color="green", verbose=self.verbose
final_result = self.llm_chain.predict(
callbacks=_run_manager.get_child(), **llm_inputs
)
_run_manager.on_text(final_result, color="green", verbose=self.verbose)
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result["intermediate_steps"] = intermediate_steps
@ -107,6 +131,18 @@ class SQLDatabaseChain(Chain):
def _chain_type(self) -> str:
return "sql_database_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> SQLDatabaseChain:
prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT)
llm_chain = LLMChain(llm=llm, prompt=prompt)
return cls(llm_chain=llm_chain, database=db, **kwargs)
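Mirroring the updated class docstring, a minimal sketch (the SQLite path and question are illustrative):

```python
from langchain import OpenAI, SQLDatabase
from langchain.chains import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db, verbose=True)
db_chain.run("How many employees are there?")
```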
class SQLDatabaseSequentialChain(Chain):
"""Chain for querying SQL database that is a sequential chain.
@ -118,6 +154,10 @@ class SQLDatabaseSequentialChain(Chain):
This is useful in cases where the number of tables in the database is large.
"""
decider_chain: LLMChain
sql_chain: SQLDatabaseChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
return_intermediate_steps: bool = False
@classmethod
@ -138,11 +178,6 @@ class SQLDatabaseSequentialChain(Chain):
)
return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs)
decider_chain: LLMChain
sql_chain: SQLDatabaseChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the singular input key.
@ -162,25 +197,32 @@ class SQLDatabaseSequentialChain(Chain):
else:
return [self.output_key, "intermediate_steps"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
_table_names = self.sql_chain.database.get_usable_table_names()
table_names = ", ".join(_table_names)
llm_inputs = {
"query": inputs[self.input_key],
"table_names": table_names,
}
table_names_to_use = self.decider_chain.predict_and_parse(**llm_inputs)
self.callback_manager.on_text(
"Table names to use:", end="\n", verbose=self.verbose
table_names_to_use = self.decider_chain.predict_and_parse(
callbacks=_run_manager.get_child(), **llm_inputs
)
self.callback_manager.on_text(
_run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(table_names_to_use), color="yellow", verbose=self.verbose
)
new_inputs = {
self.sql_chain.input_key: inputs[self.input_key],
"table_names_to_use": table_names_to_use,
}
return self.sql_chain(new_inputs, return_only_outputs=True)
return self.sql_chain(
new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True
)
@property
def _chain_type(self) -> str:

View File

@ -1,6 +1,7 @@
"""Load summarizing chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.base_language import BaseLanguageModel
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
@ -8,7 +9,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):

View File

@ -1,6 +1,7 @@
"""Chain that runs an arbitrary python function."""
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
@ -35,5 +36,9 @@ class TransformChain(Chain):
"""
return self.output_variables
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
return self.transform(inputs)

View File

@ -2,6 +2,10 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
from langchain.schema import (
@ -85,7 +89,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
) # trim off the trailing ' ' that might come from the "Assistant: "
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
@ -98,9 +105,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
for data in stream_resp:
delta = data["completion"][len(completion) :]
completion = data["completion"]
self.callback_manager.on_llm_new_token(
if run_manager:
run_manager.on_llm_new_token(
delta,
verbose=self.verbose,
)
else:
response = self.client.completion(**params)
@ -109,7 +116,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
@ -122,15 +132,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
async for data in stream_resp:
delta = data["completion"][len(completion) :]
completion = data["completion"]
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
if run_manager:
await run_manager.on_llm_new_token(
delta,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_new_token(
delta,
verbose=self.verbose,
)
else:
response = await self.client.acompletion(**params)

View File

@ -1,21 +1,30 @@
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Dict, List, Optional
from pydantic import Extra, Field, validator
from pydantic import Extra, Field, root_validator
import langchain
from langchain.callbacks import get_callback_manager
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForLLMRun,
CallbackManager,
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema import (
AIMessage,
BaseLanguageModel,
BaseMessage,
ChatGeneration,
ChatResult,
HumanMessage,
LLMResult,
PromptValue,
get_buffer_string,
)
@ -26,7 +35,19 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel, ABC):
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
callbacks: Callbacks = None
callback_manager: Optional[BaseCallbackManager] = None
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
class Config:
"""Configuration for this pydantic object."""
@ -34,98 +55,130 @@ class BaseChatModel(BaseLanguageModel, ABC):
extra = Extra.forbid
arbitrary_types_allowed = True
@validator("callback_manager", pre=True, always=True)
def set_callback_manager(
cls, callback_manager: Optional[BaseCallbackManager]
) -> BaseCallbackManager:
"""If callback manager is None, set it.
This allows users to pass in None as callback manager, which is a nice UX.
"""
return callback_manager or get_callback_manager()
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
def generate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Top Level call"""
results = [self._generate(m, stop=stop) for m in messages]
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
try:
results = [
self._generate(m, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._generate(m, stop=stop)
for m in messages
]
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
return LLMResult(generations=generations, llm_output=llm_output)
output = LLMResult(generations=generations, llm_output=llm_output)
run_manager.on_llm_end(output)
return output
async def agenerate(
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
"""Top Level call"""
results = await asyncio.gather(
*[self._agenerate(m, stop=stop) for m in messages]
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
try:
results = await asyncio.gather(
*[
self._agenerate(m, stop=stop, run_manager=run_manager)
if new_arg_supported
else self._agenerate(m, stop=stop)
for m in messages
]
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
return LLMResult(generations=generations, llm_output=llm_output)
output = LLMResult(generations=generations, llm_output=llm_output)
await run_manager.on_llm_end(output)
return output
def generate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
prompt_strings = [p.to_string() for p in prompts]
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
)
try:
output = self.generate(prompt_messages, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
async def agenerate_prompt(
self, prompts: List[PromptValue], stop: Optional[List[str]] = None
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
prompt_strings = [p.to_string() for p in prompts]
if self.callback_manager.is_async:
await self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
)
else:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose
)
try:
output = await self.agenerate(prompt_messages, stop=stop)
except (KeyboardInterrupt, Exception) as e:
if self.callback_manager.is_async:
await self.callback_manager.on_llm_error(e, verbose=self.verbose)
else:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
if self.callback_manager.is_async:
await self.callback_manager.on_llm_end(output, verbose=self.verbose)
else:
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
@abstractmethod
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Top Level call"""
@abstractmethod
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Top Level call"""
def __call__(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
) -> BaseMessage:
return self._generate(messages, stop=stop).generations[0].message
generation = self.generate(
[messages], stop=stop, callbacks=callbacks
).generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
result = self([HumanMessage(content=message)], stop=stop)
@ -134,15 +187,21 @@ class BaseChatModel(BaseLanguageModel, ABC):
class SimpleChatModel(BaseChatModel):
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
output_str = self._call(messages, stop=stop)
output_str = self._call(messages, stop=stop, run_manager=run_manager)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@abstractmethod
def _call(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Simpler interface."""

View File

@ -14,6 +14,10 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
AIMessage,
@ -242,7 +246,10 @@ class ChatOpenAI(BaseChatModel):
return {"token_usage": overall_token_usage, "model_name": self.model_name}
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
if self.streaming:
@ -255,10 +262,8 @@ class ChatOpenAI(BaseChatModel):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{"content": inner_completion, "role": role}
)
@ -287,7 +292,10 @@ class ChatOpenAI(BaseChatModel):
return ChatResult(generations=generations, llm_output=llm_output)
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
if self.streaming:
@ -300,16 +308,8 @@ class ChatOpenAI(BaseChatModel):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token
if self.callback_manager.is_async:
await self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
else:
self.callback_manager.on_llm_new_token(
token,
verbose=self.verbose,
)
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{"content": inner_completion, "role": role}
)
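End to end, streaming callbacks are now wired per run rather than per model instance. A usage sketch (OpenAI key assumed):

```python
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(streaming=True, temperature=0)
# on_llm_new_token now flows through the per-run manager, so handlers
# can be supplied at call time instead of being baked into the model.
chat([HumanMessage(content="Tell me a joke.")],
     callbacks=[StreamingStdOutCallbackHandler()])
```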

View File

@ -2,6 +2,10 @@
import datetime
from typing import List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult
@ -33,13 +37,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
return_pl_id: Optional[bool] = False
def _generate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Call ChatOpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request
request_start_time = datetime.datetime.now().timestamp()
generated_responses = super()._generate(messages, stop)
generated_responses = super()._generate(messages, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
message_dicts, params = super()._create_message_dicts(messages, stop)
for i, generation in enumerate(generated_responses.generations):
@ -67,13 +74,16 @@ class PromptLayerChatOpenAI(ChatOpenAI):
return generated_responses
async def _agenerate(
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
"""Call ChatOpenAI agenerate and then call PromptLayer to log."""
from promptlayer.utils import get_api_key, promptlayer_api_request_async
request_start_time = datetime.datetime.now().timestamp()
generated_responses = await super()._agenerate(messages, stop)
generated_responses = await super()._agenerate(messages, stop, run_manager)
request_end_time = datetime.datetime.now().timestamp()
message_dicts, params = super()._create_message_dicts(messages, stop)
for i, generation in enumerate(generated_responses.generations):

View File

@ -1,6 +1,7 @@
"""A chain for evaluating ReAct style agents."""
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
@ -94,7 +95,11 @@ Tool output: {output}"""
return ["score", "reasoning"]
return ["score"]
def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
raw_output = self.eval_chain.run(
{"tool_descriptions": self._tools_description, **inputs}
)

View File

@ -1,8 +1,11 @@
"""BabyAGI agent."""
from collections import deque
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.experimental.autonomous_agents.baby_agi.task_creation import (
TaskCreationChain,
@ -13,7 +16,6 @@ from langchain.experimental.autonomous_agents.baby_agi.task_execution import (
from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import (
TaskPrioritizationChain,
)
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore
@ -112,7 +114,11 @@ class BabyAGI(Chain, BaseModel):
objective=objective, context="\n".join(context), task=task
)
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the agent."""
objective = inputs["objective"]
first_task = inputs.get("first_task", "Make a todo list")

View File

@ -1,5 +1,5 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.base_language import BaseLanguageModel
class TaskCreationChain(LLMChain):

View File

@ -1,5 +1,5 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.base_language import BaseLanguageModel
class TaskExecutionChain(LLMChain):

View File

@ -1,5 +1,5 @@
from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.base_language import BaseLanguageModel
class TaskPrioritizationChain(LLMChain):

View File

@ -5,9 +5,9 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.experimental.generative_agents.memory import GenerativeAgentMemory
from langchain.prompts import PromptTemplate
from langchain.schema import BaseLanguageModel
class GenerativeAgent(BaseModel):

View File

@ -3,9 +3,10 @@ import re
from typing import Any, Dict, List, Optional
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.prompts import PromptTemplate
from langchain.retrievers import TimeWeightedVectorStoreRetriever
from langchain.schema import BaseLanguageModel, BaseMemory, Document
from langchain.schema import BaseMemory, Document
logger = logging.getLogger(__name__)

View File

@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional
import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
@ -106,7 +107,12 @@ class AI21(LLM):
"""Return type of llm."""
return "ai21"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to AI21's complete endpoint.
Args:

Some files were not shown because too many files have changed in this diff.