diff --git a/docs/ecosystem/aim_tracking.ipynb b/docs/ecosystem/aim_tracking.ipynb index fa9755f9..c7b1cc62 100644 --- a/docs/ecosystem/aim_tracking.ipynb +++ b/docs/ecosystem/aim_tracking.ipynb @@ -61,7 +61,6 @@ "from datetime import datetime\n", "\n", "from langchain.llms import OpenAI\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks import AimCallbackHandler, StdOutCallbackHandler" ] }, @@ -109,8 +108,8 @@ " experiment_name=\"scenario 1: OpenAI LLM\",\n", ")\n", "\n", - "manager = CallbackManager([StdOutCallbackHandler(), aim_callback])\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)" + "callbacks = [StdOutCallbackHandler(), aim_callback]\n", + "llm = OpenAI(temperature=0, callbacks=callbacks)" ] }, { @@ -177,7 +176,7 @@ "Title: {title}\n", "Playwright: This is a synopsis for the above play:\"\"\"\n", "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n", "\n", "test_prompts = [\n", " {\"title\": \"documentary about good video games that push the boundary of game design\"},\n", @@ -249,13 +248,12 @@ ], "source": [ "# scenario 3 - Agent with Tools\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", - " callback_manager=manager,\n", - " verbose=True,\n", + " callbacks=callbacks,\n", ")\n", "agent.run(\n", " \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n", diff --git a/docs/ecosystem/clearml_tracking.ipynb b/docs/ecosystem/clearml_tracking.ipynb index 20b118c6..0fb33c2d 100644 --- a/docs/ecosystem/clearml_tracking.ipynb +++ b/docs/ecosystem/clearml_tracking.ipynb @@ -79,7 +79,6 @@ "source": [ "from datetime import datetime\n", "from langchain.callbacks import ClearMLCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI\n", "\n", "# Setup and use the ClearML Callback\n", @@ -93,9 +92,9 @@ " complexity_metrics=True,\n", " stream_logs=True\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), clearml_callback])\n", + "callbacks = [StdOutCallbackHandler(), clearml_callback]\n", "# Get the OpenAI model ready to go\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)" + "llm = OpenAI(temperature=0, callbacks=callbacks)" ] }, { @@ -523,13 +522,12 @@ "from langchain.agents import AgentType\n", "\n", "# SCENARIO 2 - Agent with Tools\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", - " callback_manager=manager,\n", - " verbose=True,\n", + " callbacks=callbacks,\n", ")\n", "agent.run(\n", " \"Who is the wife of the person who sang summer of 69?\"\n", diff --git a/docs/ecosystem/comet_tracking.ipynb b/docs/ecosystem/comet_tracking.ipynb index a32646c3..4271b2ef 100644 --- a/docs/ecosystem/comet_tracking.ipynb +++ b/docs/ecosystem/comet_tracking.ipynb @@ -121,7 +121,6 @@ "from datetime import datetime\n", "\n", 
"from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI\n", "\n", "comet_callback = CometCallbackHandler(\n", @@ -131,8 +130,8 @@ " tags=[\"llm\"],\n", " visualizations=[\"dep\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9, callbacks=callbacks, verbose=True)\n", "\n", "llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\", \"Tell me a fact\"] * 3)\n", "print(\"LLM result\", llm_result)\n", @@ -153,7 +152,6 @@ "outputs": [], "source": [ "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", "from langchain.prompts import PromptTemplate\n", @@ -164,15 +162,14 @@ " stream_logs=True,\n", " tags=[\"synopsis-chain\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9, callbacks=callbacks)\n", "\n", "template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n", "Title: {title}\n", "Playwright: This is a synopsis for the above play:\"\"\"\n", "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n", "\n", "test_prompts = [{\"title\": \"Documentary about Bigfoot in Paris\"}]\n", "print(synopsis_chain.apply(test_prompts))\n", @@ -194,7 +191,6 @@ "source": [ "from langchain.agents import initialize_agent, load_tools\n", "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI\n", "\n", "comet_callback = CometCallbackHandler(\n", @@ -203,15 +199,15 @@ " stream_logs=True,\n", " tags=[\"agent\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9, callbacks=callbacks)\n", "\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callbacks=callbacks)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=\"zero-shot-react-description\",\n", - " callback_manager=manager,\n", + " callbacks=callbacks,\n", " verbose=True,\n", ")\n", "agent.run(\n", @@ -255,7 +251,6 @@ "from rouge_score import rouge_scorer\n", "\n", "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", "from langchain.prompts import PromptTemplate\n", @@ -298,10 +293,10 @@ " tags=[\"custom_metrics\"],\n", " custom_metrics=rouge_score.compute_metric,\n", ")\n", - "manager = 
CallbackManager([StdOutCallbackHandler(), comet_callback])\n", - "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "callbacks = [StdOutCallbackHandler(), comet_callback]\n", + "llm = OpenAI(temperature=0.9)\n", "\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)\n", "\n", "test_prompts = [\n", " {\n", @@ -323,7 +318,7 @@ " \"\"\"\n", " }\n", "]\n", - "print(synopsis_chain.apply(test_prompts))\n", + "print(synopsis_chain.apply(test_prompts, callbacks=callbacks))\n", "comet_callback.flush_tracker(synopsis_chain, finish=True)" ] } diff --git a/docs/ecosystem/gpt4all.md b/docs/ecosystem/gpt4all.md index 36422eb5..7dc5a025 100644 --- a/docs/ecosystem/gpt4all.md +++ b/docs/ecosystem/gpt4all.md @@ -3,6 +3,7 @@ This page covers how to use the `GPT4All` wrapper within LangChain. The tutorial is divided into two parts: installation and setup, followed by usage with an example. ## Installation and Setup + - Install the Python package with `pip install pyllamacpp` - Download a [GPT4All model](https://github.com/nomic-ai/pyllamacpp#supported-model) and place it in your desired directory @@ -28,16 +29,16 @@ To stream the model's predictions, add in a CallbackManager. ```python from langchain.llms import GPT4All -from langchain.callbacks.base import CallbackManager from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler + # There are many CallbackHandlers supported, such as # from langchain.callbacks.streamlit import StreamlitCallbackHandler -callback_manager = CallbackManager([StreamingStdOutCallbackHandler()]) -model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8, callback_handler=callback_handler, verbose=True) +callbacks = [StreamingStdOutCallbackHandler()] +model = GPT4All(model="./models/gpt4all-model.bin", n_ctx=512, n_threads=8) # Generate text. Tokens are streamed through the callback manager. 
-model("Once upon a time, ") +model("Once upon a time, ", callbacks=callbacks) ``` ## Model File diff --git a/docs/ecosystem/wandb_tracking.ipynb b/docs/ecosystem/wandb_tracking.ipynb index 9ead0230..78e4fb6a 100644 --- a/docs/ecosystem/wandb_tracking.ipynb +++ b/docs/ecosystem/wandb_tracking.ipynb @@ -50,7 +50,6 @@ "source": [ "from datetime import datetime\n", "from langchain.callbacks import WandbCallbackHandler, StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.llms import OpenAI" ] }, @@ -196,8 +195,8 @@ " name=\"llm\",\n", " tags=[\"test\"],\n", ")\n", - "manager = CallbackManager([StdOutCallbackHandler(), wandb_callback])\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)" + "callbacks = [StdOutCallbackHandler(), wandb_callback]\n", + "llm = OpenAI(temperature=0, callbacks=callbacks)" ] }, { @@ -484,7 +483,7 @@ "Title: {title}\n", "Playwright: This is a synopsis for the above play:\"\"\"\n", "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n", - "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callbacks=callbacks)\n", "\n", "test_prompts = [\n", " {\n", @@ -577,16 +576,15 @@ ], "source": [ "# SCENARIO 3 - Agent with Tools\n", - "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n", + "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm)\n", "agent = initialize_agent(\n", " tools,\n", " llm,\n", " agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,\n", - " callback_manager=manager,\n", - " verbose=True,\n", ")\n", "agent.run(\n", - " \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n", + " \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\",\n", + " callbacks=callbacks,\n", ")\n", "wandb_callback.flush_tracker(agent, reset=False, finish=True)" ] diff --git a/docs/index.rst b/docs/index.rst index 04e3abcb..0533d78c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,6 +44,8 @@ These modules are, in increasing order of complexity: - `Agents <./modules/agents.html>`_: Agents involve an LLM making decisions about which Actions to take, taking that Action, seeing an Observation, and repeating that until done. LangChain provides a standard interface for agents, a selection of agents to choose from, and examples of end to end agents. +- `Callbacks <./modules/callbacks/getting_started.html>`_: It can be difficult to track all that occurs inside a chain or agent - callbacks help add a level of observability and introspection. + .. 
toctree:: :maxdepth: 1 @@ -57,6 +59,7 @@ These modules are, in increasing order of complexity: ./modules/memory.md ./modules/chains.md ./modules/agents.md + ./modules/callbacks/getting_started.ipynb Use Cases ---------- diff --git a/docs/modules/agents/agent_executors/examples/async_agent.ipynb b/docs/modules/agents/agent_executors/examples/async_agent.ipynb index 3ef46cb4..cc9a92b3 100644 --- a/docs/modules/agents/agent_executors/examples/async_agent.ipynb +++ b/docs/modules/agents/agent_executors/examples/async_agent.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 7, "id": "da5df06c-af6f-4572-b9f5-0ab971c16487", "metadata": { "tags": [] @@ -42,7 +42,6 @@ "from langchain.agents import AgentType\n", "from langchain.llms import OpenAI\n", "from langchain.callbacks.stdout import StdOutCallbackHandler\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.tracers import LangChainTracer\n", "from aiohttp import ClientSession\n", "\n", @@ -57,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "id": "fd4c294e-b1d6-44b8-b32e-2765c017e503", "metadata": { "tags": [] @@ -73,16 +72,15 @@ "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", "Action: Search\n", "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Rafael Nadal's age\n", + "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", "Action: Search\n", "Action Input: \"Rafael Nadal age\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 36 raised to the 0.334 power\n", + "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n", "Action: Calculator\n", "Action Input: 36^0.334\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n", - "\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", "\n", @@ -93,18 +91,17 @@ "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", "Action: Search\n", "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mJason Sudeikis\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Jason Sudeikis' age\n", + "Observation: \u001b[33;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", "Action: Search\n", - "Action Input: \"Jason Sudeikis age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m47 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 47 raised to the 0.23 power\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: 
\u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", "Action: Calculator\n", - "Action Input: 47^0.23\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.4242784855673896\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Jason Sudeikis, Olivia Wilde's boyfriend, is 47 years old and his age raised to the 0.23 power is 2.4242784855673896.\u001b[0m\n", + "Action Input: 29^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", + "Final Answer: Harry Styles, Olivia Wilde's boyfriend, is 29 years old and his age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", @@ -113,17 +110,17 @@ "\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", "Action: Search\n", "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mMax Verstappen\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Max Verstappen's age\n", + "Observation: \u001b[33;1m\u001b[1;3mMichael Schumacher (top left) and Lewis Hamilton (top right) have each won the championship a record seven times during their careers, while Sebastian Vettel ( ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", "Action: Search\n", - "Action Input: \"Max Verstappen Age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m25 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.23 power\n", + "Action Input: \"Michael Schumacher age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m54 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.23 power\n", "Action: Calculator\n", - "Action Input: 25^0.23\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.84599359907945\u001b[0m\n", + "Action Input: 54^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.502940725307012\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Max Verstappen, 25 years old, raised to the 0.23 power is 1.84599359907945.\u001b[0m\n", + "Final Answer: Michael Schumacher, aged 54, raised to the 0.23 power is 2.502940725307012.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", @@ -132,18 +129,17 @@ "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n", "Action: Search\n", "Action Input: \"US Open women's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu defeated Serena Williams in the final, 6–3, 7–5 to win the women's singles tennis title at the 2019 US Open. 
It was her first major title, and she became the first Canadian, as well as the first player born in the 2000s, to win a major singles title.\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Bianca Andreescu's age.\n", + "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out her age\n", "Action: Search\n", "Action Input: \"Bianca Andreescu age\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3m22 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the age of Bianca Andreescu and can calculate her age raised to the 0.34 power.\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.34 power\n", "Action: Calculator\n", "Action Input: 22^0.34\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", - "Final Answer: Bianca Andreescu won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.8603798598506933.\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: Bianca Andreescu, aged 22, won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.86.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", @@ -160,35 +156,32 @@ "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n", "Action: Calculator\n", "Action Input: 53^0.19\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\n", - "\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", "Final Answer: Jay-Z is Beyonce's husband and his age raised to the 0.19 power is 2.12624064206896.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "Serial executed in 65.11 seconds.\n" + "Serial executed in 52.47 seconds.\n" ] } ], "source": [ - "def generate_serially():\n", - " for q in questions:\n", - " llm = OpenAI(temperature=0)\n", - " tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", - " agent = initialize_agent(\n", - " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", - " )\n", - " agent.run(q)\n", + "llm = OpenAI(temperature=0)\n", + "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")\n", "\n", "s = time.perf_counter()\n", - "generate_serially()\n", + "for q in questions:\n", + " agent.run(q)\n", "elapsed = time.perf_counter() - s\n", "print(f\"Serial executed in {elapsed:0.2f} seconds.\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "id": "076d7b85-45ec-465d-8b31-c2ad119c3438", "metadata": { "tags": [] @@ -202,6 +195,9 @@ "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\n", @@ -210,182 +206,94 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", - "Action: Search\n", - "Action Input: 
\"Olivia Wilde boyfriend\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Beyonce's husband is and then calculate his age raised to the 0.19 power.\n", - "Action: Search\n", - "Action Input: \"Who is Beyonce's husband?\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mJay-Z\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", - "Action: Search\n", - "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open women's final in 2019 and then calculate her age raised to the 0.34 power.\n", "Action: Search\n", "Action Input: \"US Open women's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mJason Sudeikis\u001b[0m\n", - "Thought:\n", - "Observation: \u001b[33;1m\u001b[1;3mMax Verstappen\u001b[0m\n", - "Thought:\n", - "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu defeated Serena Williams in the final, 6–3, 7–5 to win the women's singles tennis title at the 2019 US Open. It was her first major title, and she became the first Canadian, as well as the first player born in the 2000s, to win a major singles title.\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Jason Sudeikis' age\n", - "Action: Search\n", - "Action Input: \"Jason Sudeikis age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Jay-Z's age\n", - "Action: Search\n", - "Action Input: \"How old is Jay-Z?\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m53 years\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mBianca Andreescu\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", "Action: Search\n", "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. 
It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Beyonce's husband is and then calculate his age raised to the 0.19 power.\n", + "Action: Search\n", + "Action Input: \"Who is Beyonce's husband?\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mOlivia Wilde started dating Harry Styles after ending her years-long engagement to Jason Sudeikis — see their relationship timeline.\u001b[0m\n", "Thought:\n", - "Observation: \u001b[33;1m\u001b[1;3m47 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Max Verstappen's age\n", + "Observation: \u001b[33;1m\u001b[1;3mJay-Z\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[33;1m\u001b[1;3mMichael Schumacher (top left) and Lewis Hamilton (top right) have each won the championship a record seven times during their careers, while Sebastian Vettel ( ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out her age\n", "Action: Search\n", - "Action Input: \"Max Verstappen Age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m25 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Bianca Andreescu's age.\n", + "Action Input: \"Bianca Andreescu age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Jay-Z's age\n", "Action: Search\n", - "Action Input: \"Bianca Andreescu age\"\u001b[0m\n", + "Action Input: \"How old is Jay-Z?\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m53 years\u001b[0m\n", + "Thought:\n", "Observation: \u001b[33;1m\u001b[1;3m22 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n", - "Action: Calculator\n", - "Action Input: 53^0.19\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", "Action: Search\n", - "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 47 raised to the 0.23 power\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.34 power\n", "Action: Calculator\n", - "Action Input: 47^0.23\u001b[0m\n", + "Action Input: 22^0.34\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 53 raised to the 0.19 power\n", + "Action: Calculator\n", + "Action Input: 53^0.19\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action: Calculator\n", + "Action Input: 29^0.23\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Michael Schumacher age\"\u001b[0m\n", + "Observation: \n", "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 25 raised to the 0.23 power\n", - "Action: Calculator\n", - "Action Input: 25^0.23\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\n", - "\u001b[0m\n", 
- "Thought:\u001b[32;1m\u001b[1;3m I now know the age of Bianca Andreescu and can calculate her age raised to the 0.34 power.\n", - "Action: Calculator\n", - "Action Input: 22^0.34\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.84599359907945\u001b[0m\n", + "Thought:\u001b[33;1m\u001b[1;3m54 years\u001b[0m\n", "Thought:\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.4242784855673896\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now need to calculate his age raised to the 0.334 power\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.12624064206896\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", "Action: Calculator\n", - "Action Input: 36^0.334\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.8603798598506933\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Jay-Z is Beyonce's husband and his age raised to the 0.19 power is 2.12624064206896.\u001b[0m\n", - "\n", + "Action Input: 36^0.334\u001b[0m\u001b[32;1m\u001b[1;3m I now need to calculate the age raised to the 0.23 power\n", + "Action: Calculator\n", + "Action Input: 54^0.23\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Max Verstappen, 25 years old, raised to the 0.23 power is 1.84599359907945.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Jason Sudeikis, Olivia Wilde's boyfriend, is 47 years old and his age raised to the 0.23 power is 2.4242784855673896.\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.502940725307012\u001b[0m\n", + "Thought:\n", + "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I now know the final answer.\n", - "Final Answer: Bianca Andreescu won the US Open women's final in 2019 and her age raised to the 0.34 power is 2.8603798598506933.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n", - "Concurrent executed in 12.38 seconds.\n" + "Concurrent executed in 14.49 seconds.\n" ] } ], "source": [ - "async def generate_concurrently():\n", - " agents = []\n", - " # To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n", - " # but you must manually close the client session at the end of your program/event loop\n", - " aiosession = ClientSession()\n", - " for _ in questions:\n", - " manager = CallbackManager([StdOutCallbackHandler()])\n", - " llm = OpenAI(temperature=0, callback_manager=manager)\n", - " async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n", - " agents.append(\n", - " initialize_agent(async_tools, llm, 
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n", - " )\n", - " tasks = [async_agent.arun(q) for async_agent, q in zip(agents, questions)]\n", - " await asyncio.gather(*tasks)\n", - " await aiosession.close()\n", + "llm = OpenAI(temperature=0)\n", + "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")\n", "\n", "s = time.perf_counter()\n", - "# If running this outside of Jupyter, use asyncio.run(generate_concurrently())\n", - "await generate_concurrently()\n", + "# If running this outside of Jupyter, use asyncio.run or loop.run_until_complete\n", + "tasks = [agent.arun(q) for q in questions]\n", + "await asyncio.gather(*tasks)\n", "elapsed = time.perf_counter() - s\n", "print(f\"Concurrent executed in {elapsed:0.2f} seconds.\")" ] - }, - { - "cell_type": "markdown", - "id": "97ef285c-4a43-4a4e-9698-cd52a1bc56c9", - "metadata": {}, - "source": [ - "## Using Tracing with Asynchronous Agents\n", - "\n", - "To use tracing with async agents, you must pass in a custom `CallbackManager` with `LangChainTracer` to each agent running asynchronously. This way, you avoid collisions while the trace is being collected." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "44bda05a-d33e-4e91-9a71-a0f3f96aae95", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", - "Action: Search\n", - "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to find out Rafael Nadal's age\n", - "Action: Search\n", - "Action Input: \"Rafael Nadal age\"\u001b[0m\n", - "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 36 raised to the 0.334 power\n", - "Action: Calculator\n", - "Action Input: 36^0.334\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", - "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", - "\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - } - ], - "source": [ - "# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n", - "# but you must manually close the client session at the end of your program/event loop\n", - "aiosession = ClientSession()\n", - "tracer = LangChainTracer()\n", - "tracer.load_default_session()\n", - "manager = CallbackManager([StdOutCallbackHandler(), tracer])\n", - "\n", - "# Pass the manager into the llm if you want llm calls traced.\n", - "llm = OpenAI(temperature=0, callback_manager=manager)\n", - "\n", - "async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession)\n", - "async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n", - "await async_agent.arun(questions[0])\n", - "await aiosession.close()" - ] } ], "metadata": { @@ -404,7 +312,7 @@ "name": "python", "nbconvert_exporter": "python", 
"pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb b/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb index c81cc19e..6bbb4ad4 100644 --- a/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb +++ b/docs/modules/agents/agents/custom_agent_with_tool_retrieval.ipynb @@ -373,6 +373,7 @@ "metadata": {}, "outputs": [], "source": [ + "tools = get_tools(\"whats the weather?\")\n", "tool_names = [tool.name for tool in tools]\n", "agent = LLMSingleActionAgent(\n", " llm_chain=llm_chain, \n", diff --git a/docs/modules/agents/tools/examples/arxiv.ipynb b/docs/modules/agents/tools/examples/arxiv.ipynb index 0df922e9..38027d3c 100644 --- a/docs/modules/agents/tools/examples/arxiv.ipynb +++ b/docs/modules/agents/tools/examples/arxiv.ipynb @@ -75,6 +75,7 @@ } ], "source": [ + "\n", "arxiv = ArxivAPIWrapper()\n", "docs = arxiv.run(\"1605.08386\")\n", "docs" @@ -163,7 +164,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.16" + "version": "3.10.6" } }, "nbformat": 4, diff --git a/docs/modules/agents/tools/examples/gradio_tools.ipynb b/docs/modules/agents/tools/examples/gradio_tools.ipynb index a1a8c2ca..d4a18918 100644 --- a/docs/modules/agents/tools/examples/gradio_tools.ipynb +++ b/docs/modules/agents/tools/examples/gradio_tools.ipynb @@ -69,7 +69,8 @@ } ], "source": [ - "StableDiffusionTool().langchain.run(\"Please create a photo of a dog riding a skateboard\")" + "local_file_path = StableDiffusionTool().langchain.run(\"Please create a photo of a dog riding a skateboard\")\n", + "local_file_path" ] }, { @@ -89,7 +90,7 @@ "metadata": {}, "outputs": [], "source": [ - "im = Image.open(\"/Users/harrisonchase/workplace/langchain/docs/modules/agents/tools/examples/b61c1dd9-47e2-46f1-a47c-20d27640993d/tmp4ap48vnm.jpg\")" + "im = Image.open(local_file_path)" ] }, { diff --git a/docs/modules/agents/tools/examples/python.ipynb b/docs/modules/agents/tools/examples/python.ipynb index db2824f2..a1428545 100644 --- a/docs/modules/agents/tools/examples/python.ipynb +++ b/docs/modules/agents/tools/examples/python.ipynb @@ -19,6 +19,7 @@ "metadata": {}, "outputs": [], "source": [ + "from langchain.agents import Tool\n", "from langchain.utilities import PythonREPL" ] }, @@ -59,7 +60,14 @@ "id": "54fc1f03", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# You can create the tool to pass to an agent\n", + "repl_tool = Tool(\n", + " name=\"python_repl\",\n", + " description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. If you want to see the output of a value, you should print it out with `print(...)`.\",\n", + " func=python_repl\n", + ")" + ] } ], "metadata": { diff --git a/docs/modules/agents/tools/examples/serpapi.ipynb b/docs/modules/agents/tools/examples/serpapi.ipynb index c77821ca..c4ad0a6b 100644 --- a/docs/modules/agents/tools/examples/serpapi.ipynb +++ b/docs/modules/agents/tools/examples/serpapi.ipynb @@ -102,7 +102,15 @@ "id": "e0a1dc1c", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from langchain.agents import Tool\n", + "# You can create the tool to pass to an agent\n", + "repl_tool = Tool(\n", + " name=\"python_repl\",\n", + " description=\"A Python shell. Use this to execute python commands. Input should be a valid python command. 
If you want to see the output of a value, you should print it out with `print(...)`.\",\n", + " func=search.run,\n", + ")" + ] } ], "metadata": { diff --git a/docs/modules/callbacks/getting_started.ipynb b/docs/modules/callbacks/getting_started.ipynb index cc74a3e8..109907bd 100644 --- a/docs/modules/callbacks/getting_started.ipynb +++ b/docs/modules/callbacks/getting_started.ipynb @@ -17,33 +17,7 @@ "source": [ "LangChain provides a callback system that allows you to hook into the various stages of your LLM application. This is useful for logging, [monitoring](https://python.langchain.com/en/latest/tracing.html), [streaming](https://python.langchain.com/en/latest/modules/models/llms/examples/streaming_llm.html), and other tasks.\n", "\n", - "You can subscribe to these events by using the `callback_manager` argument available throughout the API. A `CallbackManager` is an object that manages a list of `CallbackHandlers`. The `CallbackManager` will call the appropriate method on each handler when the event is triggered." - ] - }, - { - "cell_type": "markdown", - "id": "fdb72e8d-a02a-474d-96bf-f5759432afc8", - "metadata": { - "tags": [] - }, - "source": [ - "```python\n", - "class CallbackManager(BaseCallbackHandler):\n", - " \"\"\"Base callback manager that can be used to handle callbacks from LangChain.\"\"\"\n", - "\n", - " def add_handler(self, callback: BaseCallbackHandler) -> None:\n", - " \"\"\"Add a handler to the callback manager.\"\"\"\n", - "\n", - " def remove_handler(self, handler: BaseCallbackHandler) -> None:\n", - " \"\"\"Remove a handler from the callback manager.\"\"\"\n", - "\n", - " def set_handler(self, handler: BaseCallbackHandler) -> None:\n", - " \"\"\"Set handler as the only handler on the callback manager.\"\"\"\n", - " self.set_handlers([handler])\n", - "\n", - " def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:\n", - " \"\"\"Set handlers as the only handlers on the callback manager.\"\"\"\n", - "```" + "You can subscribe to these events by using the `callbacks` argument available throughout the API. This argument list of handler objects, which are expected to implement one or more of the methods described in the API docs." ] }, { @@ -62,88 +36,97 @@ }, "source": [ "```python\n", - "class BaseCallbackHandler(ABC):\n", + "class BaseCallbackHandler:\n", " \"\"\"Base callback handler that can be used to handle callbacks from langchain.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_start(\n", " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when LLM starts running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n", " \"\"\"Run on new LLM token. 
Only available when streaming is enabled.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:\n", " \"\"\"Run when LLM ends running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_llm_error(\n", " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when LLM errors.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_chain_start(\n", " self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when chain starts running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:\n", " \"\"\"Run when chain ends running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_chain_error(\n", " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when chain errors.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_tool_start(\n", " self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when tool starts running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_tool_end(self, output: str, **kwargs: Any) -> Any:\n", " \"\"\"Run when tool ends running.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_tool_error(\n", " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", " ) -> Any:\n", " \"\"\"Run when tool errors.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_text(self, text: str, **kwargs: Any) -> Any:\n", " \"\"\"Run on arbitrary text.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n", " \"\"\"Run on agent action.\"\"\"\n", "\n", - " @abstractmethod\n", " def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:\n", " \"\"\"Run on agent end.\"\"\"\n", "```" ] }, + { + "cell_type": "markdown", + "id": "cbccd7d1", + "metadata": {}, + "source": [ + "## How to use callbacks\n", + "\n", + "The `callbacks` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) in two different places:\n", + "\n", + "- **Constructor callbacks**: defined in the constructor, eg. `LLMChain(callbacks=[handler])`, which will be used for all calls made on that object, and will be scoped to that object only, eg. if you pass a handler to the `LLMChain` constructor, it will not be used by the Model attached to that chain.\n", + "- **Request callbacks**: defined in the `call()`/`run()`/`apply()` methods used for issuing a request, eg. `chain.call(inputs, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).\n", + "\n", + "The `verbose` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) as a constructor argument, eg. `LLMChain(verbose=True)`, and it is equivalent to passing a `ConsoleCallbackHandler` to the `callbacks` argument of that object and all child objects. This is useful for debugging, as it will log all events to the console.\n", + "\n", + "### When do you want to use each of these?\n", + "\n", + "- Constructor callbacks are most useful for use cases such as logging, monitoring, etc., which are _not specific to a single request_, but rather to the entire chain. 
For example, if you want to log all the requests made to an LLMChain, you would pass a handler to the constructor.\n", + "- Request callbacks are most useful for use cases such as streaming, where you want to stream the output of a single request to a specific websocket connection, or other similar use cases. For example, if you want to stream the output of a single request to a websocket, you would pass a handler to the `call()` method" + ] + }, { "cell_type": "markdown", "id": "d3bf3304-43fb-47ad-ae50-0637a17018a2", "metadata": {}, "source": [ - "## Creating and Using a Custom `CallbackHandler`\n", + "## Using an existing handler\n", "\n", - "By default, a shared CallbackManager with the StdOutCallbackHandler will be used by models, chains, agents, and tools. However, you can pass in your own CallbackManager with a custom CallbackHandler:" + "LangChain provides a few built-in handlers that you can use to get started. These are available in the `langchain/callbacks` module. The most basic handler is the `StdOutCallbackHandler`, which simply logs all events to `stdout`. In the future we will add more default handlers to the library. \n", + "\n", + "**Note** when the `verbose` flag on the object is set to true, the `StdOutCallbackHandler` will be invoked even without being explicitly passed in." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "80532dfc-d687-4147-a0c9-1f90cc3e868c", "metadata": { "tags": [] @@ -155,16 +138,23 @@ "text": [ "\n", "\n", - "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "AgentAction(tool='Search', tool_input=\"US Open men's final 2019 winner\", log=' I need to find out who won the US Open men\\'s final in 2019 and then calculate his age raised to the 0.334 power.\\nAction: Search\\nAction Input: \"US Open men\\'s final 2019 winner\"')\n", - "Rafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. 
It was his fourth US ...\n", - "AgentAction(tool='Search', tool_input='Rafael Nadal age', log=' I need to find out the age of the winner\\nAction: Search\\nAction Input: \"Rafael Nadal age\"')\n", - "36 years\n", - "AgentAction(tool='Calculator', tool_input='36^0.334', log=' I now need to calculate his age raised to the 0.334 power\\nAction: Calculator\\nAction Input: 36^0.334')\n", - "Answer: 3.3098250249682484\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n", "\n", - " I now know the final answer\n", - "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", + "Prompt after formatting:\n", + "\u001b[32;1m\u001b[1;3m1 + 2 = \u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] @@ -172,117 +162,102 @@ { "data": { "text/plain": [ - "\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\"" + "'\\n\\n3'" ] }, - "execution_count": 1, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "from typing import Any, Dict, List, Optional, Union\n", - "\n", - "from langchain.agents import initialize_agent, load_tools\n", - "from langchain.agents import AgentType\n", - "from langchain.callbacks.base import CallbackManager, BaseCallbackHandler\n", + "from langchain.callbacks import StdOutCallbackHandler\n", + "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", - "from langchain.schema import AgentAction, AgentFinish, LLMResult\n", + "from langchain.prompts import PromptTemplate\n", "\n", - "class MyCustomCallbackHandler(BaseCallbackHandler):\n", - " \"\"\"Custom CallbackHandler.\"\"\"\n", + "handler = StdOutCallbackHandler()\n", + "llm = OpenAI()\n", + "prompt = PromptTemplate.from_template(\"1 + {number} = \")\n", "\n", - " def on_llm_start(\n", - " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Print out the prompts.\"\"\"\n", - " pass\n", + "# First, let's explicitly set the StdOutCallbackHandler in `callbacks`\n", + "chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])\n", + "chain.run(number=2)\n", "\n", - " def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + "# Then, let's use the `verbose` flag to achieve the same result\n", + "chain = LLMChain(llm=llm, prompt=prompt, verbose=True)\n", + "chain.run(number=2)\n", "\n", - " def on_llm_new_token(self, token: str, **kwargs: Any) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + "# Finally, let's use the request `callbacks` to achieve the same result\n", + "chain = LLMChain(llm=llm, prompt=prompt)\n", + "chain.run(number=2, callbacks=[handler])" + ] + }, + { + "cell_type": "markdown", + "id": "389c8448-5283-49e3-8c04-dbe1522e202c", + "metadata": {}, + "source": [ + "## Creating a custom handler\n", "\n", - " def on_llm_error(\n", - " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", + "You can create a 
custom handler to set on the object as well. In the example below, we'll implement streaming with a custom handler." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1b2e6588-0681-4cab-937a-7cc4790cea9a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "My custom handler, token: \n", + "My custom handler, token: Why\n", + "My custom handler, token: did\n", + "My custom handler, token: the\n", + "My custom handler, token: tomato\n", + "My custom handler, token: turn\n", + "My custom handler, token: red\n", + "My custom handler, token: ?\n", + "My custom handler, token: Because\n", + "My custom handler, token: it\n", + "My custom handler, token: saw\n", + "My custom handler, token: the\n", + "My custom handler, token: salad\n", + "My custom handler, token: dressing\n", + "My custom handler, token: !\n", + "My custom handler, token: \n" + ] + }, + { + "data": { + "text/plain": [ + "AIMessage(content='Why did the tomato turn red? Because it saw the salad dressing!', additional_kwargs={})" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.callbacks.base import BaseCallbackHandler\n", + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.schema import HumanMessage\n", "\n", - " def on_chain_start(\n", - " self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Print out that we are entering a chain.\"\"\"\n", - " class_name = serialized[\"name\"]\n", - " print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n", + "class MyCustomHandler(BaseCallbackHandler):\n", + " def on_llm_new_token(self, token: str, **kwargs) -> None:\n", + " print(f\"My custom handler, token: {token}\")\n", "\n", - " def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n", - " \"\"\"Print out that we finished a chain.\"\"\"\n", - " print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n", + "# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n", + "# Additionally, we pass in a list with our custom handler\n", + "chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomHandler()])\n", "\n", - " def on_chain_error(\n", - " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", - "\n", - " def on_tool_start(\n", - " self,\n", - " serialized: Dict[str, Any],\n", - " input_str: str,\n", - " **kwargs: Any,\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", - "\n", - " def on_agent_action(\n", - " self, action: AgentAction, color: Optional[str] = None, **kwargs: Any\n", - " ) -> Any:\n", - " \"\"\"Run on agent action.\"\"\"\n", - " print(action)\n", - "\n", - " def on_tool_end(\n", - " self,\n", - " output: str,\n", - " color: Optional[str] = None,\n", - " observation_prefix: Optional[str] = None,\n", - " llm_prefix: Optional[str] = None,\n", - " **kwargs: Any,\n", - " ) -> None:\n", - " \"\"\"If not the final action, print out observation.\"\"\"\n", - " print(output)\n", - "\n", - " def on_tool_error(\n", - " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Do nothing.\"\"\"\n", - " pass\n", - "\n", - " def on_text(\n", - " self,\n", - " text: str,\n", - " color: Optional[str] = None,\n", - " end: str = \"\",\n", - " **kwargs: Optional[str],\n", - " ) -> None:\n", - " \"\"\"Run when agent 
ends.\"\"\"\n", - " print(text)\n", - "\n", - " def on_agent_finish(\n", - " self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Run on agent end.\"\"\"\n", - " print(finish.log)\n", - "manager = CallbackManager([MyCustomCallbackHandler()])\n", - "llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)\n", - "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, callback_manager=manager)\n", - "agent = initialize_agent(\n", - " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager\n", - ")\n", - "agent.run(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")" + "chat([HumanMessage(content=\"Tell me a joke\")])" ] }, { @@ -292,9 +267,11 @@ "tags": [] }, "source": [ - "## Async Support\n", + "## Async Callbacks\n", "\n", - "If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` and `AsyncCallbackManager` to avoid blocking the runloop." + "If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` to avoid blocking the runloop. \n", + "\n", + "**Advanced** if you use a sync `CallbackHandler` while using an async method to run your llm/chain/tool/agent, it will still work. However, under the hood, it will be called with [`run_in_executor`](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor) which can cause issues if your `CallbackHandler` is not thread-safe." ] }, { @@ -310,58 +287,589 @@ "output_type": "stream", "text": [ "zzzz....\n", + "Hi! I just woke up. Your llm is starting\n", + "Sync handler being called in a `thread_pool_executor`: token: \n", + "Sync handler being called in a `thread_pool_executor`: token: Why\n", + "Sync handler being called in a `thread_pool_executor`: token: don\n", + "Sync handler being called in a `thread_pool_executor`: token: 't\n", + "Sync handler being called in a `thread_pool_executor`: token: scientists\n", + "Sync handler being called in a `thread_pool_executor`: token: trust\n", + "Sync handler being called in a `thread_pool_executor`: token: atoms\n", + "Sync handler being called in a `thread_pool_executor`: token: ?\n", + "\n", + "\n", + "Sync handler being called in a `thread_pool_executor`: token: Because\n", + "Sync handler being called in a `thread_pool_executor`: token: they\n", + "Sync handler being called in a `thread_pool_executor`: token: make\n", + "Sync handler being called in a `thread_pool_executor`: token: up\n", + "Sync handler being called in a `thread_pool_executor`: token: everything\n", + "Sync handler being called in a `thread_pool_executor`: token: !\n", + "Sync handler being called in a `thread_pool_executor`: token: \n", + "zzzz....\n", + "Hi! I just woke up. 
Your llm is ending\n" + ] + }, + { + "data": { + "text/plain": [ + "LLMResult(generations=[[ChatGeneration(text=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", generation_info=None, message=AIMessage(content=\"Why don't scientists trust atoms?\\n\\nBecause they make up everything!\", additional_kwargs={}))]], llm_output={'token_usage': {}, 'model_name': 'gpt-3.5-turbo'})" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import asyncio\n", + "from typing import Any, Dict, List\n", + "from langchain.schema import LLMResult\n", + "from langchain.callbacks.base import AsyncCallbackHandler\n", + "\n", + "class MyCustomSyncHandler(BaseCallbackHandler):\n", + " def on_llm_new_token(self, token: str, **kwargs) -> None:\n", + " print(f\"Sync handler being called in a `thread_pool_executor`: token: {token}\")\n", + "\n", + "class MyCustomAsyncHandler(AsyncCallbackHandler):\n", + " \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n", + "\n", + " async def on_llm_start(\n", + " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", + " ) -> None:\n", + " \"\"\"Run when chain starts running.\"\"\"\n", + " print(\"zzzz....\")\n", + " await asyncio.sleep(0.3)\n", + " class_name = serialized[\"name\"]\n", + " print(\"Hi! I just woke up. Your llm is starting\")\n", + "\n", + " async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n", + " \"\"\"Run when chain ends running.\"\"\"\n", + " print(\"zzzz....\")\n", + " await asyncio.sleep(0.3)\n", + " print(\"Hi! I just woke up. Your llm is ending\")\n", + "\n", + "# To enable streaming, we pass in `streaming=True` to the ChatModel constructor\n", + "# Additionally, we pass in a list with our custom handler\n", + "chat = ChatOpenAI(max_tokens=25, streaming=True, callbacks=[MyCustomSyncHandler(), MyCustomAsyncHandler()])\n", + "\n", + "await chat.agenerate([[HumanMessage(content=\"Tell me a joke\")]])" + ] + }, + { + "cell_type": "markdown", + "id": "d26dbb34-fcc3-401c-a115-39c7620d2d65", + "metadata": {}, + "source": [ + "## Using multiple handlers, passing in handlers\n", + "\n", + "In the previous examples, we passed in callback handlers upon creation of an object by using `callbacks=`. In this case, the callbacks will be scoped to that particular object. \n", + "\n", + "However, in many cases, it is advantageous to pass in handlers instead when running the object. When we pass through `CallbackHandlers` using the `callbacks` keyword arg when executing an run, those callbacks will be issued by all nested objects involved in the execution. For example, when a handler is passed through to an `Agent`, it will be used for all callbacks related to the agent and all the objects involved in the agent's execution, in this case, the `Tools`, `LLMChain`, and `LLM`.\n", + "\n", + "This prevents us from having to manually attach the handlers to each individual nested object." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8eec8756-1828-45cb-9699-38ac8543a150", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "on_chain_start AgentExecutor\n", + "on_chain_start LLMChain\n", + "on_llm_start OpenAI\n", + "on_llm_start (I'm the second handler!!) 
OpenAI\n", + "on_new_token I\n", + "on_new_token need\n", + "on_new_token to\n", + "on_new_token use\n", + "on_new_token a\n", + "on_new_token calculator\n", + "on_new_token to\n", + "on_new_token solve\n", + "on_new_token this\n", + "on_new_token .\n", + "on_new_token \n", + "Action\n", + "on_new_token :\n", + "on_new_token Calculator\n", + "on_new_token \n", + "Action\n", + "on_new_token Input\n", + "on_new_token :\n", + "on_new_token 2\n", + "on_new_token ^\n", + "on_new_token 0\n", + "on_new_token .\n", + "on_new_token 235\n", + "on_new_token \n", + "on_agent_action AgentAction(tool='Calculator', tool_input='2^0.235', log=' I need to use a calculator to solve this.\\nAction: Calculator\\nAction Input: 2^0.235')\n", + "on_tool_start Calculator\n", + "on_chain_start LLMMathChain\n", + "on_chain_start LLMChain\n", + "on_llm_start OpenAI\n", + "on_llm_start (I'm the second handler!!) OpenAI\n", + "on_new_token \n", + "\n", + "on_new_token ```text\n", + "on_new_token \n", + "\n", + "on_new_token 2\n", + "on_new_token **\n", + "on_new_token 0\n", + "on_new_token .\n", + "on_new_token 235\n", + "on_new_token \n", + "\n", + "on_new_token ```\n", + "\n", + "on_new_token ...\n", + "on_new_token num\n", + "on_new_token expr\n", + "on_new_token .\n", + "on_new_token evaluate\n", + "on_new_token (\"\n", + "on_new_token 2\n", + "on_new_token **\n", + "on_new_token 0\n", + "on_new_token .\n", + "on_new_token 235\n", + "on_new_token \")\n", + "on_new_token ...\n", + "on_new_token \n", + "\n", + "on_new_token \n", + "on_chain_start LLMChain\n", + "on_llm_start OpenAI\n", + "on_llm_start (I'm the second handler!!) OpenAI\n", + "on_new_token I\n", + "on_new_token now\n", + "on_new_token know\n", + "on_new_token the\n", + "on_new_token final\n", + "on_new_token answer\n", + "on_new_token .\n", + "on_new_token \n", + "Final\n", + "on_new_token Answer\n", + "on_new_token :\n", + "on_new_token 1\n", + "on_new_token .\n", + "on_new_token 17\n", + "on_new_token 690\n", + "on_new_token 67\n", + "on_new_token 372\n", + "on_new_token 187\n", + "on_new_token 674\n", + "on_new_token \n" + ] + }, + { + "data": { + "text/plain": [ + "'1.1769067372187674'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from typing import Dict, Union, Any, List\n", + "\n", + "from langchain.callbacks.base import BaseCallbackHandler\n", + "from langchain.schema import AgentAction\n", + "from langchain.agents import AgentType, initialize_agent, load_tools\n", + "from langchain.callbacks import tracing_enabled\n", + "from langchain.llms import OpenAI\n", + "\n", + "# First, define custom callback handler implementations\n", + "class MyCustomHandlerOne(BaseCallbackHandler):\n", + " def on_llm_start(\n", + " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", + " ) -> Any:\n", + " print(f\"on_llm_start {serialized['name']}\")\n", + "\n", + " def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n", + " print(f\"on_new_token {token}\")\n", + "\n", + " def on_llm_error(\n", + " self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n", + " ) -> Any:\n", + " \"\"\"Run when LLM errors.\"\"\"\n", + "\n", + " def on_chain_start(\n", + " self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n", + " ) -> Any:\n", + " print(f\"on_chain_start {serialized['name']}\")\n", + "\n", + " def on_tool_start(\n", + " self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n", + " ) -> Any:\n", + " print(f\"on_tool_start 
{serialized['name']}\")\n", + "\n", + " def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n", + " print(f\"on_agent_action {action}\")\n", + "\n", + "class MyCustomHandlerTwo(BaseCallbackHandler):\n", + " def on_llm_start(\n", + " self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n", + " ) -> Any:\n", + " print(f\"on_llm_start (I'm the second handler!!) {serialized['name']}\")\n", + "\n", + "# Instantiate the handlers\n", + "handler1 = MyCustomHandlerOne()\n", + "handler2 = MyCustomHandlerTwo()\n", + "\n", + "# Setup the agent. Only the `llm` will issue callbacks for handler2\n", + "llm = OpenAI(temperature=0, streaming=True, callbacks=[handler2])\n", + "tools = load_tools([\"llm-math\"], llm=llm)\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION\n", + ")\n", + "\n", + "# Callbacks for handler1 will be issued by every object involved in the \n", + "# Agent execution (llm, llmchain, tool, agent executor)\n", + "agent.run(\"What is 2 raised to the 0.235 power?\", callbacks=[handler1])" + ] + }, + { + "cell_type": "markdown", + "id": "32b29135-f852-4492-88ed-547275c72c53", + "metadata": {}, + "source": [ + "# Tracing and Token Counting" + ] + }, + { + "cell_type": "markdown", + "id": "fbb606d6-2863-46c5-8347-9f0bdb3805bb", + "metadata": {}, + "source": [ + "Tracing and token counting are two capabilities we provide which are built on our callbacks mechanism." + ] + }, + { + "cell_type": "markdown", + "id": "f62cd10c-494c-47d6-aa98-6e926cb9c456", + "metadata": {}, + "source": [ + "## Tracing" + ] + }, + { + "cell_type": "markdown", + "id": "d5a74b3f-3769-4a4f-99c7-b6a3b20a94e2", + "metadata": {}, + "source": [ + "There are two recommended ways to trace your LangChains:\n", + "\n", + "1. Setting the `LANGCHAIN_TRACING` environment variable to `\"true\"`. \n", + "2. Using a context manager `with tracing_enabled()` to trace a particular block of code.\n", + "\n", + "**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f164dfd5-d987-4b6a-a7c8-019c651ce47f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from langchain.agents import AgentType, initialize_agent, load_tools\n", + "from langchain.callbacks import tracing_enabled\n", + "from langchain.llms import OpenAI\n", + "\n", + "# To run the code, make sure to set OPENAI_API_KEY and SERPAPI_API_KEY\n", + "llm = OpenAI(temperature=0)\n", + "tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm)\n", + "agent = initialize_agent(\n", + " tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True\n", + ")\n", + "\n", + "questions = [\n", + " \"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\",\n", + " \"Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?\",\n", + " \"Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?\",\n", + " \"Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?\",\n", + " \"Who is Beyonce's husband? 
What is his age raised to the 0.19 power?\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6be7777e-ec1d-438f-ae33-3a93c45f808e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "zzzz....\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", + "Action: Search\n", + "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", + "Action: Calculator\n", + "Action Input: 36^0.334\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. 
In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action: Calculator\n", + "Action Input: 29^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", + "Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] } ], "source": [ - "import asyncio\n", - "from aiohttp import ClientSession\n", + "os.environ[\"LANGCHAIN_TRACING\"] = \"true\"\n", "\n", - "from langchain.callbacks.base import AsyncCallbackHandler, AsyncCallbackManager\n", - "\n", - "class MyCustomAsyncCallbackHandler(AsyncCallbackHandler):\n", - " \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n", - "\n", - " async def on_chain_start(\n", - " self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n", - " ) -> None:\n", - " \"\"\"Run when chain starts running.\"\"\"\n", - " print(\"zzzz....\")\n", - " await asyncio.sleep(0.5)\n", - " class_name = serialized[\"name\"]\n", - " print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n", - "\n", - " async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n", - " \"\"\"Run when chain ends running.\"\"\"\n", - " print(\"zzzz....\")\n", - " await asyncio.sleep(0.5)\n", - " print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n", - "\n", - "manager = AsyncCallbackManager([MyCustomAsyncCallbackHandler()])\n", - "\n", - "# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n", - "# but you must manually close the client session at the end of your program/event loop\n", - "aiosession = ClientSession()\n", - "llm = OpenAI(temperature=0, callback_manager=manager)\n", - "async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n", - "async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n", - "await async_agent.arun(\"Who won the US Open men's final in 2019? 
What is his age raised to the 0.334 power?\")\n", - "await aiosession.close()" + "# Both of the agent runs will be traced because the environment variable is set\n", + "agent.run(questions[0])\n", + "with tracing_enabled() as session:\n", + " assert session\n", + " agent.run(questions[1])" ] }, { "cell_type": "code", - "execution_count": null, - "id": "86be6304-e433-4048-880c-a92a73244407", + "execution_count": 10, + "id": "a6fd6026-dc1e-4d48-893d-3592539c7828", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", + "Action: Search\n", + "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m36 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", + "Action: Calculator\n", + "Action Input: 36^0.334\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer\n", + "Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. 
In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\n", + "Observation: \u001b[33;1m\u001b[1;3m29 years\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action: Calculator\n", + "Action Input: 29^0.23\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", + "Final Answer: Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Harry Styles is Olivia Wilde's boyfriend and his current age raised to the 0.23 power is 2.169459462491557.\"" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now, we unset the environment variable and use a context manager.\n", + "\n", + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", + "\n", + "# here, we are writing traces to \"my_test_session\"\n", + "with tracing_enabled(\"my_test_session\") as session:\n", + " assert session\n", + " agent.run(questions[0]) # this should be traced\n", + "\n", + "agent.run(questions[1]) # this should not be traced" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9383a351-4983-44e9-abd7-ef942e1c65c4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\u001b[32;1m\u001b[1;3m I need to find out who won the grand prix and then calculate their age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Formula 1 Grand Prix Winner\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who won the US Open men's final in 2019 and then calculate his age raised to the 0.334 power.\n", + "Action: Search\n", + "Action Input: \"US Open men's final 2019 winner\"\u001b[0m\u001b[33;1m\u001b[1;3mRafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out who Olivia Wilde's boyfriend is and then calculate his age raised to the 0.23 power.\n", + "Action: Search\n", + "Action Input: \"Olivia Wilde boyfriend\"\u001b[0m\u001b[33;1m\u001b[1;3mSudeikis and Wilde's relationship ended in November 2020. Wilde was publicly served with court documents regarding child custody while she was presenting Don't Worry Darling at CinemaCon 2022. In January 2021, Wilde began dating singer Harry Styles after meeting during the filming of Don't Worry Darling.\u001b[0m\u001b[33;1m\u001b[1;3mLewis Hamilton has won 103 Grands Prix during his career. He won 21 races with McLaren and has won 82 with Mercedes. 
Lewis Hamilton holds the record for the ...\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out the age of the winner\n", + "Action: Search\n", + "Action Input: \"Rafael Nadal age\"\u001b[0m\u001b[33;1m\u001b[1;3m36 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Harry Styles' age.\n", + "Action: Search\n", + "Action Input: \"Harry Styles age\"\u001b[0m\u001b[32;1m\u001b[1;3m I need to find out Lewis Hamilton's age\n", + "Action: Search\n", + "Action Input: \"Lewis Hamilton Age\"\u001b[0m\u001b[33;1m\u001b[1;3m29 years\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate the age raised to the 0.334 power\n", + "Action: Calculator\n", + "Action Input: 36^0.334\u001b[0m\u001b[32;1m\u001b[1;3m I need to calculate 29 raised to the 0.23 power.\n", + "Action: Calculator\n", + "Action Input: 29^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 3.3098250249682484\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.169459462491557\u001b[0m\u001b[33;1m\u001b[1;3m38 years\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m I now need to calculate 38 raised to the 0.23 power\n", + "Action: Calculator\n", + "Action Input: 38^0.23\u001b[0m\u001b[36;1m\u001b[1;3mAnswer: 2.3086081644669734\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\"" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The context manager is concurrency safe:\n", + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", + "\n", + "# start a background task\n", + "task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced\n", + "with tracing_enabled() as session:\n", + " assert session\n", + " tasks = [agent.arun(q) for q in questions[1:3]] # these should be traced\n", + " await asyncio.gather(*tasks)\n", + "\n", + "await task" + ] + }, + { + "cell_type": "markdown", + "id": "254fef1b-6b6e-4352-9cf4-363fba895ac7", "metadata": {}, + "source": [ + "## Token Counting\n", + "LangChain offers a context manager that allows you to count tokens." 
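
Besides `total_tokens`, the object yielded by `get_openai_callback` also exposes prompt/completion splits and an estimated cost. A small sketch (the exact attribute set can vary between LangChain versions, so treat these fields as an assumption to verify):

```python
from langchain.callbacks import get_openai_callback
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)

with get_openai_callback() as cb:
    llm("Tell me a joke")

# Token accounting collected by the callback handler during the run
print(cb.prompt_tokens, cb.completion_tokens, cb.total_tokens)
print(cb.total_cost)  # estimated USD cost for OpenAI models
```
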
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5c3e0b89-2c5e-4036-bdf2-fb6b750e360c", + "metadata": { + "tags": [] + }, "outputs": [], - "source": [] + "source": [ + "from langchain.callbacks import get_openai_callback\n", + "\n", + "llm = OpenAI(temperature=0)\n", + "with get_openai_callback() as cb:\n", + " llm(\"What is the square root of 4?\")\n", + "\n", + "total_tokens = cb.total_tokens\n", + "assert total_tokens > 0\n", + "\n", + "with get_openai_callback() as cb:\n", + " llm(\"What is the square root of 4?\")\n", + " llm(\"What is the square root of 4?\")\n", + "\n", + "assert cb.total_tokens == total_tokens * 2\n", + "\n", + "# You can kick off concurrent runs from within the context manager\n", + "with get_openai_callback() as cb:\n", + " await asyncio.gather(\n", + " *[llm.agenerate([\"What is the square root of 4?\"]) for _ in range(3)]\n", + " )\n", + "\n", + "assert cb.total_tokens == total_tokens * 3\n", + "\n", + "# The context manager is concurrency safe\n", + "task = asyncio.create_task(llm.agenerate([\"What is the square root of 4?\"]))\n", + "with get_openai_callback() as cb:\n", + " await llm.agenerate([\"What is the square root of 4?\"])\n", + "\n", + "await task\n", + "assert cb.total_tokens == total_tokens" + ] } ], "metadata": { diff --git a/docs/modules/chains/examples/llm_bash.ipynb b/docs/modules/chains/examples/llm_bash.ipynb index c2cb0fe6..dab1f6e4 100644 --- a/docs/modules/chains/examples/llm_bash.ipynb +++ b/docs/modules/chains/examples/llm_bash.ipynb @@ -10,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -37,7 +37,7 @@ "'Hello World\\n'" ] }, - "execution_count": 1, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -50,7 +50,7 @@ "\n", "text = \"Please write a bash script that prints 'Hello World' to the console.\"\n", "\n", - "bash_chain = LLMBashChain(llm=llm, verbose=True)\n", + "bash_chain = LLMBashChain.from_llm(llm, verbose=True)\n", "\n", "bash_chain.run(text)" ] @@ -65,11 +65,12 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from langchain.prompts.prompt import PromptTemplate\n", + "from langchain.chains.llm_bash.prompt import BashOutputParser\n", "\n", "_PROMPT_TEMPLATE = \"\"\"If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put \"#!/bin/bash\" in your answer. Make sure to reason step by step, using this format:\n", "Question: \"copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'\"\n", @@ -88,12 +89,12 @@ "That is the format. 
Begin!\n", "Question: {question}\"\"\"\n", "\n", - "PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE)" + "PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE, output_parser=BashOutputParser())" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -120,13 +121,13 @@ "'Hello World\\n'" ] }, - "execution_count": 3, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "bash_chain = LLMBashChain(llm=llm, prompt=PROMPT, verbose=True)\n", + "bash_chain = LLMBashChain.from_llm(llm, prompt=PROMPT, verbose=True)\n", "\n", "text = \"Please write a bash script that prints 'Hello World' to the console.\"\n", "\n", @@ -134,7 +135,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -145,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -177,7 +177,7 @@ "'api.ipynb\\t\\t\\tllm_summarization_checker.ipynb\\r\\nconstitutional_chain.ipynb\\tmoderation.ipynb\\r\\nllm_bash.ipynb\\t\\t\\topenai_openapi.yaml\\r\\nllm_checker.ipynb\\t\\topenapi.ipynb\\r\\nllm_math.ipynb\\t\\t\\tpal.ipynb\\r\\nllm_requests.ipynb\\t\\tsqlite.ipynb'" ] }, - "execution_count": 4, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -187,7 +187,7 @@ "\n", "\n", "persistent_process = BashProcess(persistent=True)\n", - "bash_chain = LLMBashChain.from_bash_process(llm=llm, bash_process=persistent_process, verbose=True)\n", + "bash_chain = LLMBashChain.from_llm(llm, bash_process=persistent_process, verbose=True)\n", "\n", "text = \"List the current directory then move up a level.\"\n", "\n", @@ -196,7 +196,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -224,7 +224,7 @@ "'examples\\t\\tgetting_started.ipynb\\tindex_examples\\r\\ngeneric\\t\\t\\thow_to_guides.rst'" ] }, - "execution_count": 5, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -258,7 +258,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/llm_checker.ipynb b/docs/modules/chains/examples/llm_checker.ipynb index a6bc0b73..38ed1b64 100644 --- a/docs/modules/chains/examples/llm_checker.ipynb +++ b/docs/modules/chains/examples/llm_checker.ipynb @@ -23,28 +23,16 @@ "\n", "\n", "\u001b[1m> Entering new SequentialChain chain...\u001b[0m\n", - "\u001b[1mChain 0\u001b[0m:\n", - "{'statement': '\\nNone. Mammals do not lay eggs.'}\n", "\n", - "\u001b[1mChain 1\u001b[0m:\n", - "{'assertions': '\\n• Mammals reproduce using live birth\\n• Mammals do not lay eggs\\n• Animals that lay eggs are not mammals'}\n", + "\u001b[1m> Finished chain.\u001b[0m\n", "\n", - "\u001b[1mChain 2\u001b[0m:\n", - "{'checked_assertions': '\\n1. True\\n\\n2. True\\n\\n3. 
False - Mammals are a class of animals that includes animals that lay eggs, such as monotremes (platypus and echidna).'}\n", - "\n", - "\u001b[1mChain 3\u001b[0m:\n", - "{'revised_statement': ' Monotremes, such as the platypus and echidna, lay the biggest eggs of any mammal.'}\n", - "\n", - "\n", - "\u001b[1m> Finished SequentialChain chain.\u001b[0m\n", - "\n", - "\u001b[1m> Finished LLMCheckerChain chain.\u001b[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "' Monotremes, such as the platypus and echidna, lay the biggest eggs of any mammal.'" + "' No mammal lays the biggest eggs. The Elephant Bird, which was a species of giant bird, laid the largest eggs of any bird.'" ] }, "execution_count": 1, @@ -60,7 +48,7 @@ "\n", "text = \"What type of mammal lays the biggest eggs?\"\n", "\n", - "checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n", + "checker_chain = LLMCheckerChain.from_llm(llm, verbose=True)\n", "\n", "checker_chain.run(text)" ] @@ -89,7 +77,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/llm_math.ipynb b/docs/modules/chains/examples/llm_math.ipynb index 29eaaea1..c46f825e 100644 --- a/docs/modules/chains/examples/llm_math.ipynb +++ b/docs/modules/chains/examples/llm_math.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "id": "44e9ba31", "metadata": {}, "outputs": [ @@ -24,23 +24,22 @@ "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "What is 13 raised to the .3432 power?\u001b[32;1m\u001b[1;3m\n", - "```python\n", - "import math\n", - "print(math.pow(13, .3432))\n", + "```text\n", + "13 ** .3432\n", "```\n", + "...numexpr.evaluate(\"13 ** .3432\")...\n", "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\n", - "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ - "'Answer: 2.4116004626599237\\n'" + "'Answer: 2.4116004626599237'" ] }, - "execution_count": 1, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -49,102 +48,7 @@ "from langchain import OpenAI, LLMMathChain\n", "\n", "llm = OpenAI(temperature=0)\n", - "llm_math = LLMMathChain(llm=llm, verbose=True)\n", - "\n", - "llm_math.run(\"What is 13 raised to the .3432 power?\")" - ] - }, - { - "cell_type": "markdown", - "id": "2bdd5fc6", - "metadata": {}, - "source": [ - "## Customize Prompt\n", - "You can also customize the prompt that is used. Here is an example prompting it to use numpy" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "76be17b0", - "metadata": {}, - "outputs": [], - "source": [ - "from langchain.prompts.prompt import PromptTemplate\n", - "\n", - "_PROMPT_TEMPLATE = \"\"\"You are GPT-3, and you can't do math.\n", - "\n", - "You can do basic math, and your memorization abilities are impressive, but you can't do any complex calculations that a human could not do in their head. You also have an annoying tendency to just make up highly specific, but wrong, answers.\n", - "\n", - "So we hooked you up to a Python 3 kernel, and now you can execute code. If you execute code, you must print out the final answer using the print function. You MUST use the python package numpy to answer your question. 
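
The updated `LLMMathChain` output above no longer executes model-written Python; the chain asks the model for a plain expression and evaluates it locally. Roughly, that evaluation step boils down to a `numexpr` call like the following (assuming the `numexpr` package is installed; this is a simplified sketch, not the chain's exact code):

```python
import numexpr

# The chain extracts the ```text ...``` expression from the LLM output
# and evaluates it with numexpr instead of running generated Python code.
expression = "13 ** .3432"
result = numexpr.evaluate(expression)
print(float(result))  # ~2.4116004626599237
```
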
You must import numpy as np.\n", - "\n", - "\n", - "Question: ${{Question with hard calculation.}}\n", - "```python\n", - "${{Code that prints what you need to know}}\n", - "print(${{code}})\n", - "```\n", - "```output\n", - "${{Output of your code}}\n", - "```\n", - "Answer: ${{Answer}}\n", - "\n", - "Begin.\n", - "\n", - "Question: What is 37593 * 67?\n", - "\n", - "```python\n", - "import numpy as np\n", - "print(np.multiply(37593, 67))\n", - "```\n", - "```output\n", - "2518731\n", - "```\n", - "Answer: 2518731\n", - "\n", - "Question: {question}\"\"\"\n", - "\n", - "PROMPT = PromptTemplate(input_variables=[\"question\"], template=_PROMPT_TEMPLATE)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "0c42faa0", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", - "What is 13 raised to the .3432 power?\u001b[32;1m\u001b[1;3m\n", - "\n", - "```python\n", - "import numpy as np\n", - "print(np.power(13, .3432))\n", - "```\n", - "\u001b[0m\n", - "Answer: \u001b[33;1m\u001b[1;3m2.4116004626599237\n", - "\u001b[0m\n", - "\u001b[1m> Finished chain.\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "'Answer: 2.4116004626599237\\n'" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "llm_math = LLMMathChain(llm=llm, prompt=PROMPT, verbose=True)\n", + "llm_math = LLMMathChain.from_llm(llm, verbose=True)\n", "\n", "llm_math.run(\"What is 13 raised to the .3432 power?\")" ] @@ -152,7 +56,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0c62951b", + "id": "e978bb8e", "metadata": {}, "outputs": [], "source": [] @@ -174,7 +78,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/llm_summarization_checker.ipynb b/docs/modules/chains/examples/llm_summarization_checker.ipynb index 7448f84f..7436616e 100644 --- a/docs/modules/chains/examples/llm_summarization_checker.ipynb +++ b/docs/modules/chains/examples/llm_summarization_checker.ipynb @@ -221,11 +221,11 @@ "\n", "• The light from these galaxies has been traveling for over 13 billion years to reach us. - True \n", "\n", - "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 1995. \n", + "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 2004. \n", "\n", "• Exoplanets were first discovered in 1992. - True \n", "\n", - "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. It is too early to tell as the JWST has not been launched yet.\n", + "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. The JWST has not yet been launched, so it is not yet known how much detail it will be able to provide.\n", "\"\"\"\n", "\n", "Original Summary:\n", @@ -296,11 +296,11 @@ "\n", "• The light from these galaxies has been traveling for over 13 billion years to reach us. 
- True \n", "\n", - "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 1995. \n", + "• JWST has provided us with the first images of exoplanets, which are planets outside of our own solar system. - False. The first exoplanet was discovered in 1992, but the first images of exoplanets were taken by the Hubble Space Telescope in 2004. \n", "\n", "• Exoplanets were first discovered in 1992. - True \n", "\n", - "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. It is too early to tell as the JWST has not been launched yet.\n", + "• The JWST has allowed us to see exoplanets in greater detail. - Undetermined. The JWST has not yet been launched, so it is not yet known how much detail it will be able to provide.\n", "\"\"\"\n", "Result:\u001b[0m\n", "\n", @@ -312,7 +312,7 @@ "Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\n", "• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\n", "• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\n", - "• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail than ever before.\n", + "• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail when it is launched in 2023.\n", "These discoveries can spark a child's imagination about the infinite wonders of the universe.\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -321,7 +321,7 @@ { "data": { "text/plain": [ - "'Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\\n• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\\n• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\\n• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. The JWST will allow us to see them in greater detail than ever before.\\nThese discoveries can spark a child\\'s imagination about the infinite wonders of the universe.'" + "'Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\\n• In 2023, The JWST will spot a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\\n• The telescope will capture images of galaxies that are over 13 billion years old. This means that the light from these galaxies has been traveling for over 13 billion years to reach us.\\n• Exoplanets, which are planets outside of our own solar system, were first discovered in 1992. 
The JWST will allow us to see them in greater detail when it is launched in 2023.\\nThese discoveries can spark a child\\'s imagination about the infinite wonders of the universe.'" ] }, "execution_count": 1, @@ -334,7 +334,7 @@ "from langchain.llms import OpenAI\n", "\n", "llm = OpenAI(temperature=0)\n", - "checker_chain = LLMSummarizationCheckerChain(llm=llm, verbose=True, max_checks=2)\n", + "checker_chain = LLMSummarizationCheckerChain.from_llm(llm, verbose=True, max_checks=2)\n", "text = \"\"\"\n", "Your 9-year old might like these recent discoveries made by The James Webb Space Telescope (JWST):\n", "• In 2023, The JWST spotted a number of galaxies nicknamed \"green peas.\" They were given this name because they are small, round, and green, like peas.\n", @@ -407,7 +407,8 @@ "Prompt after formatting:\n", "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n", "\n", - "Checked Assertions:\"\"\"\n", + "Checked Assertions:\n", + "\"\"\"\n", "\n", "- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n", "\n", @@ -428,7 +429,8 @@ "- It is considered the northern branch of the Norwegian Sea. True\n", "\"\"\"\n", "\n", - "Original Summary:\"\"\"\n", + "Original Summary:\n", + "\"\"\"\n", "The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is one of five oceans in the world, alongside the Pacific Ocean, Atlantic Ocean, Indian Ocean, and the Southern Ocean. It is the smallest of the five oceans and is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\n", "\"\"\"\n", "\n", @@ -443,7 +445,7 @@ "\n", "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", "Prompt after formatting:\n", - "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n", + "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n", "\n", "If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n", "\n", @@ -555,7 +557,8 @@ "Prompt after formatting:\n", "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n", "\n", - "Checked Assertions:\"\"\"\n", + "Checked Assertions:\n", + "\"\"\"\n", "\n", "- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n", "\n", @@ -574,7 +577,8 @@ "- It is considered the northern branch of the Norwegian Sea. False - It is considered the northern branch of the Atlantic Ocean.\n", "\"\"\"\n", "\n", - "Original Summary:\"\"\"\n", + "Original Summary:\n", + "\"\"\"\n", "\n", "The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is an arm of the Arctic Ocean. 
It is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\n", "\"\"\"\n", @@ -583,14 +587,20 @@ "\n", "The output should have the same structure and formatting as the original summary.\n", "\n", - "Summary:\u001b[0m\n", + "Summary:\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\n", "\n", "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", "Prompt after formatting:\n", - "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n", + "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n", "\n", "If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n", "\n", @@ -701,7 +711,8 @@ "Prompt after formatting:\n", "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.\n", "\n", - "Checked Assertions:\"\"\"\n", + "Checked Assertions:\n", + "\"\"\"\n", "\n", "- The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. True\n", "\n", @@ -718,7 +729,8 @@ "- It is considered the northern branch of the Atlantic Ocean. False - The Greenland Sea is considered part of the Arctic Ocean, not the Atlantic Ocean.\n", "\"\"\"\n", "\n", - "Original Summary:\"\"\"\n", + "Original Summary:\n", + "\"\"\"\n", "\n", "\n", "The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is an arm of the Arctic Ocean. It is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. The sea is named after the country of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Atlantic Ocean.\n", @@ -735,7 +747,7 @@ "\n", "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n", "Prompt after formatting:\n", - "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true of false.\n", + "\u001b[32;1m\u001b[1;3mBelow are some assertions that have been fact checked and are labeled as true or false.\n", "\n", "If all of the assertions are true, return \"True\". If any of the assertions are false, return \"False\".\n", "\n", @@ -813,14 +825,14 @@ "from langchain.llms import OpenAI\n", "\n", "llm = OpenAI(temperature=0)\n", - "checker_chain = LLMSummarizationCheckerChain(llm=llm, verbose=True, max_checks=3)\n", + "checker_chain = LLMSummarizationCheckerChain.from_llm(llm, verbose=True, max_checks=3)\n", "text = \"The Greenland Sea is an outlying portion of the Arctic Ocean located between Iceland, Norway, the Svalbard archipelago and Greenland. It has an area of 465,000 square miles and is one of five oceans in the world, alongside the Pacific Ocean, Atlantic Ocean, Indian Ocean, and the Southern Ocean. It is the smallest of the five oceans and is covered almost entirely by water, some of which is frozen in the form of glaciers and icebergs. 
The sea is named after the island of Greenland, and is the Arctic Ocean's main outlet to the Atlantic. It is often frozen over so navigation is limited, and is considered the northern branch of the Norwegian Sea.\"\n", "checker_chain.run(text)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -1077,7 +1089,7 @@ "'Birds are not mammals, but they are a class of their own. They lay eggs, unlike mammals which give birth to live young.'" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -1087,17 +1099,10 @@ "from langchain.llms import OpenAI\n", "\n", "llm = OpenAI(temperature=0)\n", - "checker_chain = LLMSummarizationCheckerChain(llm=llm, max_checks=3, verbose=True)\n", + "checker_chain = LLMSummarizationCheckerChain.from_llm(llm, max_checks=3, verbose=True)\n", "text = \"Mammals can lay eggs, birds can lay eggs, therefore birds are mammals.\"\n", "checker_chain.run(text)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/modules/chains/examples/pal.ipynb b/docs/modules/chains/examples/pal.ipynb index 36b58072..94942ccb 100644 --- a/docs/modules/chains/examples/pal.ipynb +++ b/docs/modules/chains/examples/pal.ipynb @@ -28,7 +28,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)" + "llm = OpenAI(temperature=0, max_tokens=512)" ] }, { @@ -63,7 +63,9 @@ "cell_type": "code", "execution_count": 4, "id": "3ef64b27", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", @@ -71,17 +73,17 @@ "text": [ "\n", "\n", - "\u001B[1m> Entering new PALChain chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3mdef solution():\n", + "\u001b[1m> Entering new PALChain chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mdef solution():\n", " \"\"\"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. 
If Cindy has four pets, how many total pets do the three have?\"\"\"\n", " cindy_pets = 4\n", " marcia_pets = cindy_pets + 2\n", " jan_pets = marcia_pets * 3\n", " total_pets = cindy_pets + marcia_pets + jan_pets\n", " result = total_pets\n", - " return result\u001B[0m\n", + " return result\u001b[0m\n", "\n", - "\u001B[1m> Finished chain.\u001B[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { @@ -139,8 +141,8 @@ "text": [ "\n", "\n", - "\u001B[1m> Entering new PALChain chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3m# Put objects into a list to record ordering\n", + "\u001b[1m> Entering new PALChain chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n", "objects = []\n", "objects += [('booklet', 'blue')] * 2\n", "objects += [('booklet', 'purple')] * 2\n", @@ -151,9 +153,9 @@ "\n", "# Count number of purple objects\n", "num_purple = len([object for object in objects if object[1] == 'purple'])\n", - "answer = num_purple\u001B[0m\n", + "answer = num_purple\u001b[0m\n", "\n", - "\u001B[1m> Finished PALChain chain.\u001B[0m\n" + "\u001b[1m> Finished PALChain chain.\u001b[0m\n" ] }, { @@ -212,8 +214,8 @@ "text": [ "\n", "\n", - "\u001B[1m> Entering new PALChain chain...\u001B[0m\n", - "\u001B[32;1m\u001B[1;3m# Put objects into a list to record ordering\n", + "\u001b[1m> Entering new PALChain chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n", "objects = []\n", "objects += [('booklet', 'blue')] * 2\n", "objects += [('booklet', 'purple')] * 2\n", @@ -224,9 +226,9 @@ "\n", "# Count number of purple objects\n", "num_purple = len([object for object in objects if object[1] == 'purple'])\n", - "answer = num_purple\u001B[0m\n", + "answer = num_purple\u001b[0m\n", "\n", - "\u001B[1m> Finished chain.\u001B[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] } ], @@ -280,7 +282,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/examples/sqlite.ipynb b/docs/modules/chains/examples/sqlite.ipynb index b3b23eb4..472ac99e 100644 --- a/docs/modules/chains/examples/sqlite.ipynb +++ b/docs/modules/chains/examples/sqlite.ipynb @@ -73,7 +73,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)" ] }, { @@ -175,7 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)" ] }, { @@ -230,7 +230,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, return_intermediate_steps=True)" ] }, { @@ -285,7 +285,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, top_k=3)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, top_k=3)" ] }, { @@ -407,7 +407,7 @@ "metadata": {}, "outputs": [], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)" + "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)" ] }, { @@ -569,7 +569,7 @@ } ], "source": [ - "db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)\n", + "db_chain = 
SQLDatabaseChain.from_llm(llm, db, verbose=True)\n", "db_chain.run(\"What are some example tracks by Bach?\")" ] }, @@ -681,7 +681,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/modules/chains/generic/custom_chain.ipynb b/docs/modules/chains/generic/custom_chain.ipynb new file mode 100644 index 00000000..4916b14c --- /dev/null +++ b/docs/modules/chains/generic/custom_chain.ipynb @@ -0,0 +1,199 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "593f7553-7038-498e-96d4-8255e5ce34f0", + "metadata": {}, + "source": [ + "# Creating a custom Chain\n", + "\n", + "To implement your own custom chain you can subclass `Chain` and implement the following methods:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c19c736e-ca74-4726-bb77-0a849bcc2960", + "metadata": { + "tags": [], + "vscode": { + "languageId": "python" + } + }, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "from typing import Any, Dict, List, Optional\n", + "\n", + "from pydantic import Extra\n", + "\n", + "from langchain.base_language import BaseLanguageModel\n", + "from langchain.callbacks.manager import (\n", + " AsyncCallbackManagerForChainRun,\n", + " CallbackManagerForChainRun,\n", + ")\n", + "from langchain.chains.base import Chain\n", + "from langchain.prompts.base import BasePromptTemplate\n", + "\n", + "\n", + "class MyCustomChain(Chain):\n", + " \"\"\"\n", + " An example of a custom chain.\n", + " \"\"\"\n", + "\n", + " prompt: BasePromptTemplate\n", + " \"\"\"Prompt object to use.\"\"\"\n", + " llm: BaseLanguageModel\n", + " output_key: str = \"text\" #: :meta private:\n", + "\n", + " class Config:\n", + " \"\"\"Configuration for this pydantic object.\"\"\"\n", + "\n", + " extra = Extra.forbid\n", + " arbitrary_types_allowed = True\n", + "\n", + " @property\n", + " def input_keys(self) -> List[str]:\n", + " \"\"\"Will be whatever keys the prompt expects.\n", + "\n", + " :meta private:\n", + " \"\"\"\n", + " return self.prompt.input_variables\n", + "\n", + " @property\n", + " def output_keys(self) -> List[str]:\n", + " \"\"\"Will always return text key.\n", + "\n", + " :meta private:\n", + " \"\"\"\n", + " return [self.output_key]\n", + "\n", + " def _call(\n", + " self,\n", + " inputs: Dict[str, Any],\n", + " run_manager: Optional[CallbackManagerForChainRun] = None,\n", + " ) -> Dict[str, str]:\n", + " # Your custom chain logic goes here\n", + " # This is just an example that mimics LLMChain\n", + " prompt_value = self.prompt.format_prompt(**inputs)\n", + " \n", + " # Whenever you call a language model, or another chain, you should pass\n", + " # a callback manager to it. This allows the inner run to be tracked by\n", + " # any callbacks that are registered on the outer run.\n", + " # You can always obtain a callback manager for this by calling\n", + " # `run_manager.get_child()` as shown below.\n", + " response = self.llm.generate_prompt(\n", + " [prompt_value],\n", + " callbacks=run_manager.get_child() if run_manager else None\n", + " )\n", + "\n", + " # If you want to log something about this run, you can do so by calling\n", + " # methods on the `run_manager`, as shown below. 
This will trigger any\n", + " # callbacks that are registered for that event.\n", + " if run_manager:\n", + " run_manager.on_text(\"Log something about this run\")\n", + " \n", + " return {self.output_key: response.generations[0][0].text}\n", + "\n", + " async def _acall(\n", + " self,\n", + " inputs: Dict[str, Any],\n", + " run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n", + " ) -> Dict[str, str]:\n", + " # Your custom chain logic goes here\n", + " # This is just an example that mimics LLMChain\n", + " prompt_value = self.prompt.format_prompt(**inputs)\n", + " \n", + " # Whenever you call a language model, or another chain, you should pass\n", + " # a callback manager to it. This allows the inner run to be tracked by\n", + " # any callbacks that are registered on the outer run.\n", + " # You can always obtain a callback manager for this by calling\n", + " # `run_manager.get_child()` as shown below.\n", + " response = await self.llm.agenerate_prompt(\n", + " [prompt_value],\n", + " callbacks=run_manager.get_child() if run_manager else None\n", + " )\n", + "\n", + " # If you want to log something about this run, you can do so by calling\n", + " # methods on the `run_manager`, as shown below. This will trigger any\n", + " # callbacks that are registered for that event.\n", + " if run_manager:\n", + " await run_manager.on_text(\"Log something about this run\")\n", + " \n", + " return {self.output_key: response.generations[0][0].text}\n", + "\n", + " @property\n", + " def _chain_type(self) -> str:\n", + " return \"my_custom_chain\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "18361f89", + "metadata": { + "vscode": { + "languageId": "python" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n", + "Log something about this run\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Why did the callback function feel lonely? 
Because it was always waiting for someone to call it back!'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain.callbacks.stdout import StdOutCallbackHandler\n", + "from langchain.chat_models.openai import ChatOpenAI\n", + "from langchain.prompts.prompt import PromptTemplate\n", + "\n", + "\n", + "chain = MyCustomChain(\n", + " prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n", + " llm=ChatOpenAI()\n", + ")\n", + "\n", + "chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/modules/chains/index_examples/chat_vector_db.ipynb b/docs/modules/chains/index_examples/chat_vector_db.ipynb index e86aa162..de329013 100644 --- a/docs/modules/chains/index_examples/chat_vector_db.ipynb +++ b/docs/modules/chains/index_examples/chat_vector_db.ipynb @@ -589,7 +589,6 @@ "outputs": [], "source": [ "from langchain.chains.llm import LLMChain\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n", "from langchain.chains.question_answering import load_qa_chain\n", @@ -597,7 +596,7 @@ "# Construct a ConversationalRetrievalChain with a streaming llm for combine docs\n", "# and a separate, non-streaming llm for question generation\n", "llm = OpenAI(temperature=0)\n", - "streaming_llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "streaming_llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "\n", "question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)\n", "doc_chain = load_qa_chain(streaming_llm, chain_type=\"stuff\", prompt=QA_PROMPT)\n", diff --git a/docs/modules/models/chat/examples/streaming.ipynb b/docs/modules/models/chat/examples/streaming.ipynb index 22b27e0c..e7d0894e 100644 --- a/docs/modules/models/chat/examples/streaming.ipynb +++ b/docs/modules/models/chat/examples/streaming.ipynb @@ -80,9 +80,8 @@ } ], "source": [ - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", - "chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])" ] }, diff --git a/docs/modules/models/chat/getting_started.ipynb b/docs/modules/models/chat/getting_started.ipynb index 113d652e..cee995ec 100644 --- a/docs/modules/models/chat/getting_started.ipynb +++ b/docs/modules/models/chat/getting_started.ipynb @@ -373,9 +373,8 @@ } ], "source": [ - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", - 
"chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])\n" ] }, diff --git a/docs/modules/models/llms/examples/custom_llm.ipynb b/docs/modules/models/llms/examples/custom_llm.ipynb index 1375d639..4db92f04 100644 --- a/docs/modules/models/llms/examples/custom_llm.ipynb +++ b/docs/modules/models/llms/examples/custom_llm.ipynb @@ -22,18 +22,20 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "id": "a65696a0", "metadata": {}, "outputs": [], "source": [ - "from langchain.llms.base import LLM\n", - "from typing import Optional, List, Mapping, Any" + "from typing import Any, List, Mapping, Optional\n", + "\n", + "from langchain.callbacks.manager import CallbackManagerForLLMRun\n", + "from langchain.llms.base import LLM" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "id": "d5ceff02", "metadata": {}, "outputs": [], @@ -46,7 +48,12 @@ " def _llm_type(self) -> str:\n", " return \"custom\"\n", " \n", - " def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:\n", + " def _call(\n", + " self,\n", + " prompt: str,\n", + " stop: Optional[List[str]] = None,\n", + " run_manager: Optional[CallbackManagerForLLMRun] = None,\n", + " ) -> str:\n", " if stop is not None:\n", " raise ValueError(\"stop kwargs are not permitted.\")\n", " return prompt[:self.n]\n", @@ -67,7 +74,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "id": "10e5ece6", "metadata": {}, "outputs": [], @@ -77,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "id": "8cd49199", "metadata": {}, "outputs": [ @@ -87,7 +94,7 @@ "'This is a '" ] }, - "execution_count": 4, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -106,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "id": "9c33fa19", "metadata": {}, "outputs": [ diff --git a/docs/modules/models/llms/examples/streaming_llm.ipynb b/docs/modules/models/llms/examples/streaming_llm.ipynb index c48d1ee5..e10a79d7 100644 --- a/docs/modules/models/llms/examples/streaming_llm.ipynb +++ b/docs/modules/models/llms/examples/streaming_llm.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "4ac0ff54-540a-4f2b-8d9a-b590fec7fe07", "metadata": { "tags": [] @@ -21,14 +21,13 @@ "source": [ "from langchain.llms import OpenAI, Anthropic\n", "from langchain.chat_models import ChatOpenAI\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "from langchain.schema import HumanMessage" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "77f60a4b-f786-41f2-972e-e5bb8a48dcd5", "metadata": { "tags": [] @@ -79,7 +78,7 @@ } ], "source": [ - "llm = OpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "llm = OpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "resp = llm(\"Write me a song about sparkling water.\")" ] }, @@ -95,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "a35373f1-9ee6-4753-a343-5aee749b8527", "metadata": { 
"tags": [] @@ -136,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "22665f16-e05b-473c-a4bd-ad75744ea024", "metadata": { "tags": [] @@ -191,7 +190,7 @@ } ], "source": [ - "chat = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "chat = ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "resp = chat([HumanMessage(content=\"Write me a song about sparkling water.\")])" ] }, @@ -205,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "eadae4ba-9f21-4ec8-845d-dd43b0edc2dc", "metadata": { "tags": [] @@ -245,7 +244,7 @@ } ], "source": [ - "llm = Anthropic(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)\n", + "llm = Anthropic(streaming=True, callbacks=[StreamingStdOutCallbackHandler()], temperature=0)\n", "llm(\"Write me a song about sparkling water.\")" ] } diff --git a/docs/modules/models/llms/integrations/gpt4all.ipynb b/docs/modules/models/llms/integrations/gpt4all.ipynb index 81083afc..73bbd9b9 100644 --- a/docs/modules/models/llms/integrations/gpt4all.ipynb +++ b/docs/modules/models/llms/integrations/gpt4all.ipynb @@ -40,7 +40,6 @@ "source": [ "from langchain import PromptTemplate, LLMChain\n", "from langchain.llms import GPT4All\n", - "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler" ] }, @@ -124,9 +123,9 @@ "outputs": [], "source": [ "# Callbacks support token-wise streaming\n", - "callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n", + "callbacks = [StreamingStdOutCallbackHandler()]\n", "# Verbose is required to pass to the callback manager\n", - "llm = GPT4All(model=local_path, callback_manager=callback_manager, verbose=True)" + "llm = GPT4All(model=local_path, callbacks=callbacks, verbose=True)" ] }, { diff --git a/docs/tracing.md b/docs/tracing.md index 3214140e..59128208 100644 --- a/docs/tracing.md +++ b/docs/tracing.md @@ -6,16 +6,15 @@ First, you should install tracing and set up your environment properly. You can use either a locally hosted version of this (uses Docker) or a cloud hosted version (in closed alpha). If you're interested in using the hosted platform, please fill out the form [here](https://forms.gle/tRCEMSeopZf6TE3b6). - - [Locally Hosted Setup](./tracing/local_installation.md) - [Cloud Hosted Setup](./tracing/hosted_installation.md) ## Tracing Walkthrough -When you first access the UI, you should see a page with your tracing sessions. -An initial one "default" should already be created for you. -A session is just a way to group traces together. -If you click on a session, it will take you to a page with no recorded traces that says "No Runs." +When you first access the UI, you should see a page with your tracing sessions. +An initial one "default" should already be created for you. +A session is just a way to group traces together. +If you click on a session, it will take you to a page with no recorded traces that says "No Runs." You can create a new session with the new session form. ![](tracing/homepage.png) @@ -35,7 +34,7 @@ We can keep on clicking further and further down to explore deeper and deeper. ![](tracing/explore.png) -We can also click on the "Explore" button of the top level run to dive even deeper. 
+We can also click on the "Explore" button of the top level run to dive even deeper. Here, we can see the inputs and outputs in full, as well as all the nested traces. ![](tracing/explore_trace.png) @@ -46,11 +45,12 @@ For example, here is the lowest level trace with the exact inputs/outputs to the ![](tracing/explore_llm.png) ## Changing Sessions + 1. To initially record traces to a session other than `"default"`, you can set the `LANGCHAIN_SESSION` environment variable to the name of the session you want to record to: ```python import os -os.environ["LANGCHAIN_HANDLER"] = "langchain" +os.environ["LANGCHAIN_TRACING"] = "true" os.environ["LANGCHAIN_SESSION"] = "my_session" # Make sure this session actually exists. You can create a new session in the UI. ``` diff --git a/docs/tracing/agent_with_tracing.ipynb b/docs/tracing/agent_with_tracing.ipynb index 26b2b9d0..7facae95 100644 --- a/docs/tracing/agent_with_tracing.ipynb +++ b/docs/tracing/agent_with_tracing.ipynb @@ -5,7 +5,14 @@ "id": "5371a9bb", "metadata": {}, "source": [ - "# Tracing Walkthrough" + "# Tracing Walkthrough\n", + "\n", + "There are two recommended ways to trace your LangChains:\n", + "\n", + "1. Setting the `LANGCHAIN_TRACING` environment variable to \"true\".\n", + "1. Using a context manager with tracing_enabled() to trace a particular block of code.\n", + "\n", + "**Note** if the environment variable is set, all code will be traced, regardless of whether or not it's within the context manager." ] }, { @@ -18,24 +25,22 @@ "outputs": [], "source": [ "import os\n", - "os.environ[\"LANGCHAIN_HANDLER\"] = \"langchain\"\n", - "\n", - "## Uncomment this if using hosted setup.\n", + "os.environ[\"LANGCHAIN_TRACING\"] = \"true\"\n", "\n", + "## Uncomment below if using hosted setup.\n", "# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchain-api-gateway-57eoxz8z.uc.gateway.dev\" \n", "\n", - "## Uncomment this if you want traces to be recorded to \"my_session\" instead of default.\n", - "\n", + "## Uncomment below if you want traces to be recorded to \"my_session\" instead of \"default\".\n", "# os.environ[\"LANGCHAIN_SESSION\"] = \"my_session\" \n", "\n", "## Better to set this environment variable in the terminal\n", - "## Uncomment this if using hosted version. Replace \"my_api_key\" with your actual API Key.\n", - "\n", + "## Uncomment below if using hosted version. 
Replace \"my_api_key\" with your actual API Key.\n", "# os.environ[\"LANGCHAIN_API_KEY\"] = \"my_api_key\" \n", "\n", "import langchain\n", "from langchain.agents import Tool, initialize_agent, load_tools\n", "from langchain.agents import AgentType\n", + "from langchain.callbacks import tracing_enabled\n", "from langchain.chat_models import ChatOpenAI\n", "from langchain.llms import OpenAI" ] @@ -73,8 +78,7 @@ "\u001b[32;1m\u001b[1;3m I need to use a calculator to solve this.\n", "Action: Calculator\n", "Action Input: 2^.123243\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\n", - "\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", "Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n", "Final Answer: 1.0891804557407723\u001b[0m\n", "\n", @@ -104,7 +108,9 @@ "cell_type": "code", "execution_count": 4, "id": "4829eb1d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", @@ -113,52 +119,11 @@ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", - "\u001b[32;1m\u001b[1;3mQuestion: What is 2 raised to .123243 power?\n", - "Thought: I need a calculator to solve this problem.\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"calculator\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\u001b[0m\n", - "Observation: calculator is not a valid tool, try another one.\n", - "\u001b[32;1m\u001b[1;3mI made a mistake, I need to use the correct tool for this question.\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"calculator\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\n", - "\u001b[0m\n", - "Observation: calculator is not a valid tool, try another one.\n", - "\u001b[32;1m\u001b[1;3mI made a mistake, the tool name is actually \"calc\" instead of \"calculator\".\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"calc\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\n", - "\u001b[0m\n", - "Observation: calc is not a valid tool, try another one.\n", - "\u001b[32;1m\u001b[1;3mI made another mistake, the tool name is actually \"Calculator\" instead of \"calc\".\n", - "Action:\n", - "```\n", - "{\n", - " \"action\": \"Calculator\",\n", - " \"action_input\": \"2^0.123243\"\n", - "}\n", - "```\n", - "\n", - "\u001b[0m\n", - "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\n", - "\u001b[0m\n", - "Thought:\u001b[32;1m\u001b[1;3mThe final answer is 1.0891804557407723.\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n", "Final Answer: 1.0891804557407723\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n" @@ -186,8 +151,182 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "76abfd82", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. 
\n", + "Final Answer: 1.0891804557407723\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 5 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2193914912400514\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n", + "Final Answer: 1.2193914912400514\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + } + ], + "source": [ + "# Both of the agent runs will be traced because the environment variable is set\n", + "agent.run(\"What is 2 raised to .123243 power?\")\n", + "with tracing_enabled() as session:\n", + " agent.run(\"What is 5 raised to .123243 power?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fe833c33-033f-4806-be0c-cc3d147db13d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 5 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.2193914912400514\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. \n", + "Final Answer: 1.2193914912400514\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2 ^ .123243\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0891804557407723\u001b[0m\n", + "Thought:\u001b[32;1m\u001b[1;3mI now know the answer to the question. 
\n", + "Final Answer: 1.0891804557407723\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'1.0891804557407723'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now, we unset the environment variable and use a context manager.\n", + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", + "\n", + "# here, we are writing traces to \"my_test_session\"\n", + "with tracing_enabled(\"my_session\") as session:\n", + " assert session\n", + " agent.run(\"What is 5 raised to .123243 power?\") # this should be traced\n", + "\n", + "agent.run(\"What is 2 raised to .123243 power?\") # this should not be traced" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b34105a4-be8e-46e4-8abe-01adba3ba727", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\n", + "\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 3^0.123\u001b[0m\u001b[32;1m\u001b[1;3mI need to use a calculator to solve this.\n", + "Action: Calculator\n", + "Action Input: 2^0.123\u001b[0m\u001b[32;1m\u001b[1;3mAny number raised to the power of 0 is 1, but I'm not sure about a decimal power.\n", + "Action: Calculator\n", + "Action Input: 1^.123\u001b[0m\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.1446847956963533\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0889970153361064\u001b[0m\n", + "Thought:\n", + "Observation: \u001b[36;1m\u001b[1;3mAnswer: 1.0\u001b[0m\n", + "Thought:\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'1.0'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The context manager is concurrency safe:\n", + "import asyncio \n", + "if \"LANGCHAIN_TRACING\" in os.environ:\n", + " del os.environ[\"LANGCHAIN_TRACING\"]\n", + " \n", + "questions = [f\"What is {i} raised to .123 power?\" for i in range(1,4)]\n", + "\n", + "# start a background task\n", + "task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced\n", + "with tracing_enabled() as session:\n", + " assert session\n", + " tasks = [agent.arun(q) for q in questions[1:3]] # these should be traced\n", + " await asyncio.gather(*tasks)\n", + "\n", + "await task" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e46c85b-2ac0-4661-abed-9c2bf3036820", "metadata": {}, "outputs": [], "source": [] @@ -209,7 +348,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/langchain/__init__.py b/langchain/__init__.py index ca9fe387..c2513b28 100644 --- a/langchain/__init__.py +++ b/langchain/__init__.py @@ -5,11 +5,6 @@ from typing import Optional from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.cache import BaseCache -from langchain.callbacks import ( - set_default_callback_manager, - set_handler, - 
set_tracing_callback_manager, -) from langchain.chains import ( ConversationChain, LLMBashChain, @@ -67,7 +62,6 @@ del metadata # optional, avoids polluting the results of dir(__package__) verbose: bool = False llm_cache: Optional[BaseCache] = None -set_default_callback_manager() # For backwards compatibility SerpAPIChain = SerpAPIWrapper @@ -119,7 +113,5 @@ __all__ = [ "VectorDBQAWithSourcesChain", "QAWithSourcesChain", "PALChain", - "set_handler", - "set_tracing_callback_manager", "LlamaCpp", ] diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py index 0099a035..2ac0d539 100644 --- a/langchain/agents/agent.py +++ b/langchain/agents/agent.py @@ -13,7 +13,13 @@ import yaml from pydantic import BaseModel, root_validator from langchain.agents.tools import InvalidTool +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + Callbacks, +) from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.input import get_color_mapping @@ -23,7 +29,6 @@ from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( AgentAction, AgentFinish, - BaseLanguageModel, BaseMessage, BaseOutputParser, ) @@ -46,13 +51,17 @@ class BaseSingleActionAgent(BaseModel): @abstractmethod def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: @@ -61,13 +70,17 @@ class BaseSingleActionAgent(BaseModel): @abstractmethod async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: @@ -170,13 +183,17 @@ class BaseMultiActionAgent(BaseModel): @abstractmethod def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[List[AgentAction], AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: @@ -185,13 +202,17 @@ class BaseMultiActionAgent(BaseModel): @abstractmethod async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[List[AgentAction], AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. 
Returns: @@ -285,38 +306,52 @@ class LLMSingleActionAgent(BaseSingleActionAgent): return list(set(self.llm_chain.input_keys) - {"intermediate_steps"}) def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. """ output = self.llm_chain.run( - intermediate_steps=intermediate_steps, stop=self.stop, **kwargs + intermediate_steps=intermediate_steps, + stop=self.stop, + callbacks=callbacks, + **kwargs, ) return self.output_parser.parse(output) async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. """ output = await self.llm_chain.arun( - intermediate_steps=intermediate_steps, stop=self.stop, **kwargs + intermediate_steps=intermediate_steps, + stop=self.stop, + callbacks=callbacks, + **kwargs, ) return self.output_parser.parse(output) @@ -368,37 +403,45 @@ class Agent(BaseSingleActionAgent): return thoughts def plan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. """ full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - full_output = self.llm_chain.predict(**full_inputs) + full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) return self.output_parser.parse(full_output) async def aplan( - self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: """Given input, decided what to do. Args: intermediate_steps: Steps the LLM has taken to date, along with observations + callbacks: Callbacks to run. **kwargs: User inputs. Returns: Action specifying what tool to use. 
""" full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - full_output = await self.llm_chain.apredict(**full_inputs) + full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs) return self.output_parser.parse(full_output) def get_full_inputs( @@ -636,24 +679,27 @@ class AgentExecutor(Chain): return True - def _return(self, output: AgentFinish, intermediate_steps: list) -> Dict[str, Any]: - self.callback_manager.on_agent_finish( - output, color="green", verbose=self.verbose - ) + def _return( + self, + output: AgentFinish, + intermediate_steps: list, + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + if run_manager: + run_manager.on_agent_finish(output, color="green", verbose=self.verbose) final_output = output.return_values if self.return_intermediate_steps: final_output["intermediate_steps"] = intermediate_steps return final_output async def _areturn( - self, output: AgentFinish, intermediate_steps: list + self, + output: AgentFinish, + intermediate_steps: list, + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Dict[str, Any]: - if self.callback_manager.is_async: - await self.callback_manager.on_agent_finish( - output, color="green", verbose=self.verbose - ) - else: - self.callback_manager.on_agent_finish( + if run_manager: + await run_manager.on_agent_finish( output, color="green", verbose=self.verbose ) final_output = output.return_values @@ -667,13 +713,18 @@ class AgentExecutor(Chain): color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]], + run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: """Take a single step in the thought-action-observation loop. Override this to take control of how the agent makes and acts on choices. """ # Call the LLM to see what to do. - output = self.agent.plan(intermediate_steps, **inputs) + output = self.agent.plan( + intermediate_steps, + callbacks=run_manager.get_child() if run_manager else None, + **inputs, + ) # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): return output @@ -684,9 +735,8 @@ class AgentExecutor(Chain): actions = output result = [] for agent_action in actions: - self.callback_manager.on_agent_action( - agent_action, verbose=self.verbose, color="green" - ) + if run_manager: + run_manager.on_agent_action(agent_action, color="green") # Otherwise we lookup the tool if agent_action.tool in name_to_tool_map: tool = name_to_tool_map[agent_action.tool] @@ -700,6 +750,7 @@ class AgentExecutor(Chain): agent_action.tool_input, verbose=self.verbose, color=color, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) else: @@ -708,6 +759,7 @@ class AgentExecutor(Chain): agent_action.tool, verbose=self.verbose, color=None, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) result.append((agent_action, observation)) @@ -719,13 +771,18 @@ class AgentExecutor(Chain): color_mapping: Dict[str, str], inputs: Dict[str, str], intermediate_steps: List[Tuple[AgentAction, str]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]: """Take a single step in the thought-action-observation loop. Override this to take control of how the agent makes and acts on choices. """ # Call the LLM to see what to do. 
- output = await self.agent.aplan(intermediate_steps, **inputs) + output = await self.agent.aplan( + intermediate_steps, + callbacks=run_manager.get_child() if run_manager else None, + **inputs, + ) # If the tool chosen is the finishing tool, then we end and return. if isinstance(output, AgentFinish): return output @@ -738,12 +795,8 @@ class AgentExecutor(Chain): async def _aperform_agent_action( agent_action: AgentAction, ) -> Tuple[AgentAction, str]: - if self.callback_manager.is_async: - await self.callback_manager.on_agent_action( - agent_action, verbose=self.verbose, color="green" - ) - else: - self.callback_manager.on_agent_action( + if run_manager: + await run_manager.on_agent_action( agent_action, verbose=self.verbose, color="green" ) # Otherwise we lookup the tool @@ -759,6 +812,7 @@ class AgentExecutor(Chain): agent_action.tool_input, verbose=self.verbose, color=color, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) else: @@ -767,6 +821,7 @@ class AgentExecutor(Chain): agent_action.tool, verbose=self.verbose, color=None, + callbacks=run_manager.get_child() if run_manager else None, **tool_run_kwargs, ) return agent_action, observation @@ -778,7 +833,11 @@ class AgentExecutor(Chain): return list(result) - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run text through and get agent response.""" # Construct a mapping of tool name to tool for easy lookup name_to_tool_map = {tool.name: tool for tool in self.tools} @@ -794,10 +853,16 @@ class AgentExecutor(Chain): # We now enter the agent loop (until it returns something). while self._should_continue(iterations, time_elapsed): next_step_output = self._take_next_step( - name_to_tool_map, color_mapping, inputs, intermediate_steps + name_to_tool_map, + color_mapping, + inputs, + intermediate_steps, + run_manager=run_manager, ) if isinstance(next_step_output, AgentFinish): - return self._return(next_step_output, intermediate_steps) + return self._return( + next_step_output, intermediate_steps, run_manager=run_manager + ) intermediate_steps.extend(next_step_output) if len(next_step_output) == 1: @@ -813,7 +878,11 @@ class AgentExecutor(Chain): ) return self._return(output, intermediate_steps) - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Run text through and get agent response.""" # Construct a mapping of tool name to tool for easy lookup name_to_tool_map = {tool.name: tool for tool in self.tools} @@ -831,7 +900,11 @@ class AgentExecutor(Chain): try: while self._should_continue(iterations, time_elapsed): next_step_output = await self._atake_next_step( - name_to_tool_map, color_mapping, inputs, intermediate_steps + name_to_tool_map, + color_mapping, + inputs, + intermediate_steps, + run_manager=run_manager, ) if isinstance(next_step_output, AgentFinish): return await self._areturn(next_step_output, intermediate_steps) @@ -849,7 +922,9 @@ class AgentExecutor(Chain): output = self.agent.return_stopped_response( self.early_stopping_method, intermediate_steps, **inputs ) - return await self._areturn(output, intermediate_steps) + return await self._areturn( + output, intermediate_steps, run_manager=run_manager + ) except TimeoutError: # stop early when interrupted by the async timeout output = 
self.agent.return_stopped_response( diff --git a/langchain/agents/agent_toolkits/file_management/toolkit.py b/langchain/agents/agent_toolkits/file_management/toolkit.py index 17ae4f3a..cc7d77f7 100644 --- a/langchain/agents/agent_toolkits/file_management/toolkit.py +++ b/langchain/agents/agent_toolkits/file_management/toolkit.py @@ -54,7 +54,7 @@ class FileManagementToolkit(BaseToolkit): tools: List[BaseTool] = [] for tool in allowed_tools: tool_cls = _FILE_TOOLS[tool] - tools.append(tool_cls(root_dir=self.root_dir)) + tools.append(tool_cls(root_dir=self.root_dir)) # type: ignore return tools diff --git a/langchain/agents/agent_toolkits/openapi/planner.py b/langchain/agents/agent_toolkits/openapi/planner.py index 7fada246..a6d9fd09 100644 --- a/langchain/agents/agent_toolkits/openapi/planner.py +++ b/langchain/agents/agent_toolkits/openapi/planner.py @@ -28,13 +28,13 @@ from langchain.agents.agent_toolkits.openapi.planner_prompt import ( from langchain.agents.agent_toolkits.openapi.spec import ReducedOpenAPISpec from langchain.agents.mrkl.base import ZeroShotAgent from langchain.agents.tools import Tool +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI from langchain.memory import ReadOnlySharedMemory from langchain.prompts import PromptTemplate from langchain.prompts.base import BasePromptTemplate from langchain.requests import RequestsWrapper -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.requests.tool import BaseRequestsTool diff --git a/langchain/agents/agent_toolkits/powerbi/toolkit.py b/langchain/agents/agent_toolkits/powerbi/toolkit.py index 00056563..812e3d01 100644 --- a/langchain/agents/agent_toolkits/powerbi/toolkit.py +++ b/langchain/agents/agent_toolkits/powerbi/toolkit.py @@ -4,10 +4,10 @@ from typing import List, Optional from pydantic import Field from langchain.agents.agent_toolkits.base import BaseToolkit +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel from langchain.tools import BaseTool from langchain.tools.powerbi.prompt import QUESTION_TO_QUERY from langchain.tools.powerbi.tool import ( diff --git a/langchain/agents/agent_toolkits/sql/toolkit.py b/langchain/agents/agent_toolkits/sql/toolkit.py index 91c3de0b..085d24e3 100644 --- a/langchain/agents/agent_toolkits/sql/toolkit.py +++ b/langchain/agents/agent_toolkits/sql/toolkit.py @@ -4,7 +4,7 @@ from typing import List from pydantic import Field from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel from langchain.sql_database import SQLDatabase from langchain.tools import BaseTool from langchain.tools.sql_database.tool import ( diff --git a/langchain/agents/chat/base.py b/langchain/agents/chat/base.py index 7245c10d..04ceca71 100644 --- a/langchain/agents/chat/base.py +++ b/langchain/agents/chat/base.py @@ -5,6 +5,7 @@ from pydantic import Field from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.chat.output_parser import ChatOutputParser from langchain.agents.chat.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from 
langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate @@ -13,7 +14,7 @@ from langchain.prompts.chat import ( HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.schema import AgentAction, BaseLanguageModel +from langchain.schema import AgentAction from langchain.tools import BaseTool diff --git a/langchain/agents/conversational/base.py b/langchain/agents/conversational/base.py index 75018314..16a43a90 100644 --- a/langchain/agents/conversational/base.py +++ b/langchain/agents/conversational/base.py @@ -9,10 +9,10 @@ from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent_types import AgentType from langchain.agents.conversational.output_parser import ConvoOutputParser from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/langchain/agents/conversational_chat/base.py b/langchain/agents/conversational_chat/base.py index a91915c0..d9b83ecc 100644 --- a/langchain/agents/conversational_chat/base.py +++ b/langchain/agents/conversational_chat/base.py @@ -12,6 +12,7 @@ from langchain.agents.conversational_chat.prompt import ( SUFFIX, TEMPLATE_TOOL_RESPONSE, ) +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.prompts.base import BasePromptTemplate @@ -24,7 +25,6 @@ from langchain.prompts.chat import ( from langchain.schema import ( AgentAction, AIMessage, - BaseLanguageModel, BaseMessage, BaseOutputParser, HumanMessage, diff --git a/langchain/agents/initialize.py b/langchain/agents/initialize.py index 72784b89..9a52b151 100644 --- a/langchain/agents/initialize.py +++ b/langchain/agents/initialize.py @@ -4,8 +4,8 @@ from typing import Any, Optional, Sequence from langchain.agents.agent import AgentExecutor from langchain.agents.agent_types import AgentType from langchain.agents.loading import AGENT_TO_CLASS, load_agent +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index 399e6b6c..e11440e8 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -103,8 +103,8 @@ def _get_llm_math(llm: BaseLLM) -> BaseTool: return Tool( name="Calculator", description="Useful for when you need to answer questions about math.", - func=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).run, - coroutine=LLMMathChain(llm=llm, callback_manager=llm.callback_manager).arun, + func=LLMMathChain.from_llm(llm=llm).run, + coroutine=LLMMathChain.from_llm(llm=llm).arun, ) diff --git a/langchain/agents/mrkl/base.py b/langchain/agents/mrkl/base.py index 60f7c981..8ce4411c 100644 --- a/langchain/agents/mrkl/base.py +++ b/langchain/agents/mrkl/base.py @@ -10,10 +10,10 @@ from langchain.agents.agent_types import AgentType from langchain.agents.mrkl.output_parser import MRKLOutputParser from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS, PREFIX, SUFFIX from langchain.agents.tools import Tool +from langchain.base_language import 
BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains import LLMChain from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool diff --git a/langchain/agents/tools.py b/langchain/agents/tools.py index 7a6637c2..0f943138 100644 --- a/langchain/agents/tools.py +++ b/langchain/agents/tools.py @@ -1,9 +1,14 @@ """Interface for tools.""" from functools import partial +from inspect import signature from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union from pydantic import BaseModel, validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool, StructuredTool @@ -44,14 +49,44 @@ class Tool(BaseTool): ) return tuple(all_args), {} - def _run(self, *args: Any, **kwargs: Any) -> Any: + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool.""" - return self.func(*args, **kwargs) + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) - async def _arun(self, *args: Any, **kwargs: Any) -> Any: + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool asynchronously.""" if self.coroutine: - return await self.coroutine(*args, **kwargs) + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) raise NotImplementedError("Tool does not support async") # TODO: this is for backwards compatibility, remove in future @@ -70,11 +105,17 @@ class InvalidTool(BaseTool): name = "invalid_tool" description = "Called when tool name is invalid." - def _run(self, tool_name: str) -> str: + def _run( + self, tool_name: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Use the tool.""" return f"{tool_name} is not a valid tool, try another one." - async def _arun(self, tool_name: str) -> str: + async def _arun( + self, + tool_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return f"{tool_name} is not a valid tool, try another one." diff --git a/langchain/base_language.py b/langchain/base_language.py new file mode 100644 index 00000000..3c524ef4 --- /dev/null +++ b/langchain/base_language.py @@ -0,0 +1,60 @@ +"""Base class for all language models.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List, Optional + +from pydantic import BaseModel + +from langchain.callbacks.manager import Callbacks +from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string + + +def _get_num_tokens_default_method(text: str) -> int: + """Get the number of tokens present in the text.""" + # TODO: this method may not be exact. + # TODO: this method may differ based on model (eg codex). + try: + from transformers import GPT2TokenizerFast + except ImportError: + raise ValueError( + "Could not import transformers python package. 
" + "This is needed in order to calculate get_num_tokens. " + "Please install it with `pip install transformers`." + ) + # create a GPT-2 tokenizer instance + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + + # tokenize the text using the GPT-3 tokenizer + tokenized_text = tokenizer.tokenize(text) + + # calculate the number of tokens in the tokenized text + return len(tokenized_text) + + +class BaseLanguageModel(BaseModel, ABC): + @abstractmethod + def generate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + ) -> LLMResult: + """Take in a list of prompt values and return an LLMResult.""" + + @abstractmethod + async def agenerate_prompt( + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + ) -> LLMResult: + """Take in a list of prompt values and return an LLMResult.""" + + def get_num_tokens(self, text: str) -> int: + """Get the number of tokens present in the text.""" + return _get_num_tokens_default_method(text) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """Get the number of tokens in the message.""" + return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index c6137bf9..a85c375e 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -1,80 +1,19 @@ """Callback handlers that allow listening to events in LangChain.""" -import os -from contextlib import contextmanager -from typing import Generator, Optional from langchain.callbacks.aim_callback import AimCallbackHandler -from langchain.callbacks.base import ( - AsyncCallbackManager, - BaseCallbackHandler, - BaseCallbackManager, - CallbackManager, -) from langchain.callbacks.clearml_callback import ClearMLCallbackHandler from langchain.callbacks.comet_ml_callback import CometCallbackHandler +from langchain.callbacks.manager import ( + get_openai_callback, + tracing_enabled, +) from langchain.callbacks.openai_info import OpenAICallbackHandler -from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler -from langchain.callbacks.tracers import SharedLangChainTracer from langchain.callbacks.wandb_callback import WandbCallbackHandler - -def get_callback_manager() -> BaseCallbackManager: - """Return the shared callback manager.""" - return SharedCallbackManager() - - -def set_handler(handler: BaseCallbackHandler) -> None: - """Set handler.""" - callback = get_callback_manager() - callback.set_handler(handler) - - -def set_default_callback_manager() -> None: - """Set default callback manager.""" - default_handler = os.environ.get("LANGCHAIN_HANDLER", "stdout") - if default_handler == "stdout": - set_handler(StdOutCallbackHandler()) - elif default_handler == "langchain": - session = os.environ.get("LANGCHAIN_SESSION") - set_tracing_callback_manager(session) - else: - raise ValueError( - f"LANGCHAIN_HANDLER should be one of `stdout` " - f"or `langchain`, got {default_handler}" - ) - - -def set_tracing_callback_manager(session_name: Optional[str] = None) -> None: - """Set tracing callback manager.""" - handler = SharedLangChainTracer() - callback = get_callback_manager() - callback.set_handlers([handler, StdOutCallbackHandler()]) - if session_name is None: - handler.load_default_session() - else: - try: - 
handler.load_session(session_name) - except Exception: - raise ValueError(f"session {session_name} not found") - - -@contextmanager -def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: - """Get OpenAI callback handler in a context manager.""" - handler = OpenAICallbackHandler() - manager = get_callback_manager() - manager.add_handler(handler) - yield handler - manager.remove_handler(handler) - - __all__ = [ - "CallbackManager", - "AsyncCallbackManager", "OpenAICallbackHandler", - "SharedCallbackManager", "StdOutCallbackHandler", "AimCallbackHandler", "WandbCallbackHandler", @@ -82,8 +21,5 @@ __all__ = [ "CometCallbackHandler", "AsyncIteratorCallbackHandler", "get_openai_callback", - "set_tracing_callback_manager", - "set_default_callback_manager", - "set_handler", - "get_callback_manager", + "tracing_enabled", ] diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 5b47b82a..69ba0f1a 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -1,19 +1,174 @@ -"""Base callback handler that can be used to handle callbacks from langchain.""" -import asyncio -import functools -from abc import ABC, abstractmethod +"""Base callback handler that can be used to handle callbacks in langchain.""" +from __future__ import annotations + +import copy from typing import Any, Dict, List, Optional, Union +from uuid import UUID from langchain.schema import AgentAction, AgentFinish, LLMResult -class BaseCallbackHandler(ABC): - """Base callback handler that can be used to handle callbacks from langchain.""" +class LLMManagerMixin: + """Mixin for LLM callbacks.""" - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return False + def on_llm_new_token( + self, + token: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on new LLM token. 
Only available when streaming is enabled.""" + + def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM ends running.""" + + def on_llm_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM errors.""" + + +class ChainManagerMixin: + """Mixin for chain callbacks.""" + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain ends running.""" + + def on_chain_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain errors.""" + + def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent action.""" + + def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on agent end.""" + + +class ToolManagerMixin: + """Mixin for tool callbacks.""" + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool ends running.""" + + def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool errors.""" + + +class CallbackManagerMixin: + """Mixin for callback manager.""" + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when LLM starts running.""" + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when chain starts running.""" + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run when tool starts running.""" + + +class RunManagerMixin: + """Mixin for run manager.""" + + def on_text( + self, + text: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + """Run on arbitrary text.""" + + +class BaseCallbackHandler( + LLMManagerMixin, + ChainManagerMixin, + ToolManagerMixin, + CallbackManagerMixin, + RunManagerMixin, +): + """Base callback handler that can be used to handle callbacks from langchain.""" @property def ignore_llm(self) -> bool: @@ -30,480 +185,197 @@ class BaseCallbackHandler(ABC): """Whether to ignore agent callbacks.""" return False - @abstractmethod - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> Any: + +class AsyncCallbackHandler(BaseCallbackHandler): + """Async callback handler that can be used to handle callbacks from langchain.""" + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when LLM starts running.""" - @abstractmethod - def on_llm_new_token(self, token: str, **kwargs: Any) -> Any: + async def on_llm_new_token( + self, + token: str, + *, + run_id: 
UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run on new LLM token. Only available when streaming is enabled.""" - @abstractmethod - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any: + async def on_llm_end( + self, + response: LLMResult, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when LLM ends running.""" - @abstractmethod - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + async def on_llm_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when LLM errors.""" - @abstractmethod - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> Any: + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when chain starts running.""" - @abstractmethod - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any: + async def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when chain ends running.""" - @abstractmethod - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + async def on_chain_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when chain errors.""" - @abstractmethod - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> Any: + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when tool starts running.""" - @abstractmethod - def on_tool_end(self, output: str, **kwargs: Any) -> Any: + async def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when tool ends running.""" - @abstractmethod - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> Any: + async def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run when tool errors.""" - @abstractmethod - def on_text(self, text: str, **kwargs: Any) -> Any: + async def on_text( + self, + text: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run on arbitrary text.""" - @abstractmethod - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + async def on_agent_action( + self, + action: AgentAction, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run on agent action.""" - @abstractmethod - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + async def on_agent_finish( + self, + finish: AgentFinish, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: """Run on agent end.""" -class BaseCallbackManager(BaseCallbackHandler, ABC): +class BaseCallbackManager(CallbackManagerMixin): """Base callback manager that can be used to handle callbacks from LangChain.""" + def __init__( + self, + handlers: 
List[BaseCallbackHandler], + inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, + parent_run_id: Optional[UUID] = None, + ) -> None: + """Initialize callback manager.""" + self.handlers: List[BaseCallbackHandler] = handlers + self.inheritable_handlers: List[BaseCallbackHandler] = ( + inheritable_handlers or [] + ) + self.parent_run_id: Optional[UUID] = parent_run_id + @property def is_async(self) -> bool: """Whether the callback manager is async.""" return False - @abstractmethod - def add_handler(self, callback: BaseCallbackHandler) -> None: + def add_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: """Add a handler to the callback manager.""" + self.handlers.append(handler) + if inherit: + self.inheritable_handlers.append(handler) - @abstractmethod def remove_handler(self, handler: BaseCallbackHandler) -> None: """Remove a handler from the callback manager.""" + self.handlers.remove(handler) + self.inheritable_handlers.remove(handler) - def set_handler(self, handler: BaseCallbackHandler) -> None: + def set_handlers( + self, handlers: List[BaseCallbackHandler], inherit: bool = True + ) -> None: + """Set handlers as the only handlers on the callback manager.""" + self.handlers = [] + self.inheritable_handlers = [] + for handler in handlers: + self.add_handler(handler, inherit=inherit) + + def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: """Set handler as the only handler on the callback manager.""" - self.set_handlers([handler]) + self.set_handlers([handler], inherit=inherit) - @abstractmethod - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: - """Set handlers as the only handlers on the callback manager.""" - - -class CallbackManager(BaseCallbackManager): - """Callback manager that can be used to handle callbacks from langchain.""" - - def __init__(self, handlers: List[BaseCallbackHandler]) -> None: - """Initialize callback manager.""" - self.handlers: List[BaseCallbackHandler] = handlers - - def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_start(serialized, prompts, **kwargs) - - def on_llm_new_token( - self, token: str, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when LLM generates a new token.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_new_token(token, **kwargs) - - def on_llm_end( - self, response: LLMResult, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when LLM ends running.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_end(response) - - def on_llm_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when LLM errors.""" - for handler in self.handlers: - if not handler.ignore_llm: - if verbose or handler.always_verbose: - handler.on_llm_error(error) - - def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when chain starts running.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - handler.on_chain_start(serialized, inputs, **kwargs) - - def on_chain_end( - 
self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any - ) -> None: - """Run when chain ends running.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - handler.on_chain_end(outputs) - - def on_chain_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when chain errors.""" - for handler in self.handlers: - if not handler.ignore_chain: - if verbose or handler.always_verbose: - handler.on_chain_error(error) - - def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when tool starts running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_tool_start(serialized, input_str, **kwargs) - - def on_agent_action( - self, action: AgentAction, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_agent_action(action, **kwargs) - - def on_tool_end(self, output: str, verbose: bool = False, **kwargs: Any) -> None: - """Run when tool ends running.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_tool_end(output, **kwargs) - - def on_tool_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when tool errors.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_tool_error(error) - - def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: - """Run on additional input from chains and agents.""" - for handler in self.handlers: - if verbose or handler.always_verbose: - handler.on_text(text, **kwargs) - - def on_agent_finish( - self, finish: AgentFinish, verbose: bool = False, **kwargs: Any - ) -> None: - """Run on agent end.""" - for handler in self.handlers: - if not handler.ignore_agent: - if verbose or handler.always_verbose: - handler.on_agent_finish(finish, **kwargs) - - def add_handler(self, handler: BaseCallbackHandler) -> None: - """Add a handler to the callback manager.""" - self.handlers.append(handler) - - def remove_handler(self, handler: BaseCallbackHandler) -> None: - """Remove a handler from the callback manager.""" - self.handlers.remove(handler) - - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: - """Set handlers as the only handlers on the callback manager.""" - self.handlers = handlers - - -class AsyncCallbackHandler(BaseCallbackHandler): - """Async callback handler that can be used to handle callbacks from langchain.""" - - async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. 
Only available when streaming is enabled.""" - - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - - async def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" - - async def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" - - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - - async def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when chain errors.""" - - async def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - - async def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - - async def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when tool errors.""" - - async def on_text(self, text: str, **kwargs: Any) -> None: - """Run on arbitrary text.""" - - async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None: - """Run on agent action.""" - - async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run on agent end.""" - - -async def _handle_event_for_handler( - handler: BaseCallbackHandler, - event_name: str, - ignore_condition_name: Optional[str], - verbose: bool, - *args: Any, - **kwargs: Any -) -> None: - if ignore_condition_name is None or not getattr(handler, ignore_condition_name): - if verbose or handler.always_verbose: - event = getattr(handler, event_name) - if asyncio.iscoroutinefunction(event): - await event(*args, **kwargs) - else: - await asyncio.get_event_loop().run_in_executor( - None, functools.partial(event, *args, **kwargs) - ) - - -class AsyncCallbackManager(BaseCallbackManager): - """Async callback manager that can be used to handle callbacks from LangChain.""" - - @property - def is_async(self) -> bool: - """Return whether the handler is async.""" - return True - - def __init__(self, handlers: List[BaseCallbackHandler]) -> None: - """Initialize callback manager.""" - self.handlers: List[BaseCallbackHandler] = handlers - - async def _handle_event( - self, - event_name: str, - ignore_condition_name: Optional[str], - verbose: bool, - *args: Any, - **kwargs: Any - ) -> None: - """Generic event handler for AsyncCallbackManager.""" - await asyncio.gather( - *( - _handle_event_for_handler( - handler, event_name, ignore_condition_name, verbose, *args, **kwargs - ) - for handler in self.handlers - ) + def __copy__(self) -> "BaseCallbackManager": + return self.__class__( + self.handlers.copy(), self.inheritable_handlers.copy(), self.parent_run_id ) - async def on_llm_start( - self, - serialized: Dict[str, Any], - prompts: List[str], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - await self._handle_event( - "on_llm_start", "ignore_llm", verbose, serialized, prompts, **kwargs + def __deepcopy__(self, memo: dict) -> "BaseCallbackManager": + return self.__class__( + [copy.deepcopy(handler, memo) for handler in self.handlers], + [copy.deepcopy(handler, memo) for handler in self.inheritable_handlers], + self.parent_run_id, ) - - async def on_llm_new_token( - self, token: str, verbose: bool = False, **kwargs: Any - ) -> None: - """Run on new LLM token. 
Only available when streaming is enabled.""" - await self._handle_event( - "on_llm_new_token", "ignore_llm", verbose, token, **kwargs - ) - - async def on_llm_end( - self, response: LLMResult, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when LLM ends running.""" - await self._handle_event( - "on_llm_end", "ignore_llm", verbose, response, **kwargs - ) - - async def on_llm_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when LLM errors.""" - await self._handle_event("on_llm_error", "ignore_llm", verbose, error, **kwargs) - - async def on_chain_start( - self, - serialized: Dict[str, Any], - inputs: Dict[str, Any], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when chain starts running.""" - await self._handle_event( - "on_chain_start", "ignore_chain", verbose, serialized, inputs, **kwargs - ) - - async def on_chain_end( - self, outputs: Dict[str, Any], verbose: bool = False, **kwargs: Any - ) -> None: - """Run when chain ends running.""" - await self._handle_event( - "on_chain_end", "ignore_chain", verbose, outputs, **kwargs - ) - - async def on_chain_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when chain errors.""" - await self._handle_event( - "on_chain_error", "ignore_chain", verbose, error, **kwargs - ) - - async def on_tool_start( - self, - serialized: Dict[str, Any], - input_str: str, - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when tool starts running.""" - await self._handle_event( - "on_tool_start", "ignore_agent", verbose, serialized, input_str, **kwargs - ) - - async def on_tool_end( - self, output: str, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when tool ends running.""" - await self._handle_event( - "on_tool_end", "ignore_agent", verbose, output, **kwargs - ) - - async def on_tool_error( - self, - error: Union[Exception, KeyboardInterrupt], - verbose: bool = False, - **kwargs: Any - ) -> None: - """Run when tool errors.""" - await self._handle_event( - "on_tool_error", "ignore_agent", verbose, error, **kwargs - ) - - async def on_text(self, text: str, verbose: bool = False, **kwargs: Any) -> None: - """Run when text is printed.""" - await self._handle_event("on_text", None, verbose, text, **kwargs) - - async def on_agent_action( - self, action: AgentAction, verbose: bool = False, **kwargs: Any - ) -> None: - """Run on agent action.""" - await self._handle_event( - "on_agent_action", "ignore_agent", verbose, action, **kwargs - ) - - async def on_agent_finish( - self, finish: AgentFinish, verbose: bool = False, **kwargs: Any - ) -> None: - """Run when agent finishes.""" - await self._handle_event( - "on_agent_finish", "ignore_agent", verbose, finish, **kwargs - ) - - def add_handler(self, handler: BaseCallbackHandler) -> None: - """Add a handler to the callback manager.""" - self.handlers.append(handler) - - def remove_handler(self, handler: BaseCallbackHandler) -> None: - """Remove a handler from the callback manager.""" - self.handlers.remove(handler) - - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: - """Set handlers as the only handlers on the callback manager.""" - self.handlers = handlers diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py new file mode 100644 index 00000000..a21847a5 --- /dev/null +++ b/langchain/callbacks/manager.py @@ -0,0 +1,736 @@ +from __future__ import annotations + +import asyncio 
+import copy +import functools +import os +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any, Dict, Generator, List, Optional, Type, TypeVar, Union +from uuid import UUID, uuid4 + +from langchain.callbacks.base import ( + BaseCallbackHandler, + BaseCallbackManager, + ChainManagerMixin, + LLMManagerMixin, + RunManagerMixin, + ToolManagerMixin, +) +from langchain.callbacks.openai_info import OpenAICallbackHandler +from langchain.callbacks.stdout import StdOutCallbackHandler +from langchain.callbacks.tracers.base import TracerSession +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.schema import AgentAction, AgentFinish, LLMResult + +Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]] + +openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar( + "openai_callback", default=None +) +tracing_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( + "tracing_callback", default=None +) + + +@contextmanager +def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]: + """Get OpenAI callback handler in a context manager.""" + cb = OpenAICallbackHandler() + openai_callback_var.set(cb) + yield cb + openai_callback_var.set(None) + + +@contextmanager +def tracing_enabled( + session_name: str = "default", +) -> Generator[TracerSession, None, None]: + """Enable LangChain tracing in a context manager.""" + cb = LangChainTracer() + session = cb.load_session(session_name) + tracing_callback_var.set(cb) + yield session + tracing_callback_var.set(None) + + +def _handle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + for handler in handlers: + try: + if ignore_condition_name is None or not getattr( + handler, ignore_condition_name + ): + getattr(handler, event_name)(*args, **kwargs) + except Exception as e: + # TODO: switch this to use logging + print(f"Error in {event_name} callback: {e}") + + +async def _ahandle_event_for_handler( + handler: BaseCallbackHandler, + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + try: + if ignore_condition_name is None or not getattr(handler, ignore_condition_name): + event = getattr(handler, event_name) + if asyncio.iscoroutinefunction(event): + await event(*args, **kwargs) + else: + await asyncio.get_event_loop().run_in_executor( + None, functools.partial(event, *args, **kwargs) + ) + except Exception as e: + # TODO: switch this to use logging + print(f"Error in {event_name} callback: {e}") + + +async def _ahandle_event( + handlers: List[BaseCallbackHandler], + event_name: str, + ignore_condition_name: Optional[str], + *args: Any, + **kwargs: Any, +) -> None: + """Generic event handler for AsyncCallbackManager.""" + await asyncio.gather( + *( + _ahandle_event_for_handler( + handler, event_name, ignore_condition_name, *args, **kwargs + ) + for handler in handlers + ) + ) + + +BRM = TypeVar("BRM", bound="BaseRunManager") + + +class BaseRunManager(RunManagerMixin): + """Base class for run manager (a bound callback manager).""" + + def __init__( + self, + run_id: UUID, + handlers: List[BaseCallbackHandler], + inheritable_handlers: List[BaseCallbackHandler], + parent_run_id: Optional[UUID] = None, + ) -> None: + """Initialize run manager.""" + self.run_id = run_id + self.handlers = handlers + self.inheritable_handlers = inheritable_handlers + self.parent_run_id = 
parent_run_id + + @classmethod + def get_noop_manager(cls: Type[BRM]) -> BRM: + """Return a manager that doesn't perform any operations.""" + return cls(uuid4(), [], []) + + +class RunManager(BaseRunManager): + """Sync Run Manager.""" + + def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: + """Run when text is received.""" + _handle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class AsyncRunManager(BaseRunManager): + """Async Run Manager.""" + + async def on_text( + self, + text: str, + **kwargs: Any, + ) -> Any: + """Run when text is received.""" + await _ahandle_event( + self.handlers, + "on_text", + None, + text, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): + """Callback manager for LLM run.""" + + def on_llm_new_token( + self, + token: str, + **kwargs: Any, + ) -> None: + """Run when LLM generates a new token.""" + _handle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token=token, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + _handle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_llm_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + _handle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): + """Async callback manager for LLM run.""" + + async def on_llm_new_token( + self, + token: str, + **kwargs: Any, + ) -> None: + """Run when LLM generates a new token.""" + await _ahandle_event( + self.handlers, + "on_llm_new_token", + "ignore_llm", + token, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + await _ahandle_event( + self.handlers, + "on_llm_end", + "ignore_llm", + response, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_llm_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when LLM errors.""" + await _ahandle_event( + self.handlers, + "on_llm_error", + "ignore_llm", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class CallbackManagerForChainRun(RunManager, ChainManagerMixin): + """Callback manager for chain run.""" + + def get_child(self) -> CallbackManager: + """Get a child callback manager.""" + manager = CallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + _handle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_chain_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + _handle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + 
parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received.""" + _handle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received.""" + _handle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): + """Async callback manager for chain run.""" + + def get_child(self) -> AsyncCallbackManager: + """Get a child callback manager.""" + manager = AsyncCallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Run when chain ends running.""" + await _ahandle_event( + self.handlers, + "on_chain_end", + "ignore_chain", + outputs, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_chain_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when chain errors.""" + await _ahandle_event( + self.handlers, + "on_chain_error", + "ignore_chain", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run when agent action is received.""" + await _ahandle_event( + self.handlers, + "on_agent_action", + "ignore_agent", + action, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run when agent finish is received.""" + await _ahandle_event( + self.handlers, + "on_agent_finish", + "ignore_agent", + finish, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class CallbackManagerForToolRun(RunManager, ToolManagerMixin): + """Callback manager for tool run.""" + + def get_child(self) -> CallbackManager: + """Get a child callback manager.""" + manager = CallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + def on_tool_end( + self, + output: str, + **kwargs: Any, + ) -> None: + """Run when tool ends running.""" + _handle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + _handle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): + """Async callback manager for tool run.""" + + def get_child(self) -> AsyncCallbackManager: + """Get a child callback manager.""" + manager = AsyncCallbackManager([], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + return manager + + async def on_tool_end(self, output: str, **kwargs: Any) -> None: + """Run when tool ends running.""" + await _ahandle_event( + self.handlers, + "on_tool_end", + "ignore_agent", + output, + run_id=self.run_id, + 
parent_run_id=self.parent_run_id, + **kwargs, + ) + + async def on_tool_error( + self, + error: Union[Exception, KeyboardInterrupt], + **kwargs: Any, + ) -> None: + """Run when tool errors.""" + await _ahandle_event( + self.handlers, + "on_tool_error", + "ignore_agent", + error, + run_id=self.run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + +class CallbackManager(BaseCallbackManager): + """Callback manager that can be used to handle callbacks from langchain.""" + + def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForLLMRun: + """Run when LLM starts running.""" + if run_id is None: + run_id = uuid4() + + _handle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + prompts, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return CallbackManagerForLLMRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForChainRun: + """Run when chain starts running.""" + if run_id is None: + run_id = uuid4() + + _handle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return CallbackManagerForChainRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> CallbackManagerForToolRun: + """Run when tool starts running.""" + if run_id is None: + run_id = uuid4() + + _handle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return CallbackManagerForToolRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + ) -> CallbackManager: + """Configure the callback manager.""" + return _configure(cls, inheritable_callbacks, local_callbacks, verbose) + + +class AsyncCallbackManager(BaseCallbackManager): + """Async callback manager that can be used to handle callbacks from LangChain.""" + + @property + def is_async(self) -> bool: + """Return whether the handler is async.""" + return True + + async def on_llm_start( + self, + serialized: Dict[str, Any], + prompts: List[str], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForLLMRun: + """Run when LLM starts running.""" + if run_id is None: + run_id = uuid4() + + await _ahandle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + prompts, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return AsyncCallbackManagerForLLMRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + async def on_chain_start( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForChainRun: + """Run when chain starts running.""" + if run_id is None: + run_id = uuid4() + + await _ahandle_event( + self.handlers, + "on_chain_start", + "ignore_chain", + serialized, + inputs, + run_id=run_id, + 
parent_run_id=self.parent_run_id, + **kwargs, + ) + + return AsyncCallbackManagerForChainRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + async def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + run_id: Optional[UUID] = None, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> AsyncCallbackManagerForToolRun: + """Run when tool starts running.""" + if run_id is None: + run_id = uuid4() + + await _ahandle_event( + self.handlers, + "on_tool_start", + "ignore_agent", + serialized, + input_str, + run_id=run_id, + parent_run_id=self.parent_run_id, + **kwargs, + ) + + return AsyncCallbackManagerForToolRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + + @classmethod + def configure( + cls, + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, + ) -> AsyncCallbackManager: + """Configure the callback manager.""" + return _configure(cls, inheritable_callbacks, local_callbacks, verbose) + + +T = TypeVar("T", CallbackManager, AsyncCallbackManager) + + +def _configure( + callback_manager_cls: Type[T], + inheritable_callbacks: Callbacks = None, + local_callbacks: Callbacks = None, + verbose: bool = False, +) -> T: + """Configure the callback manager.""" + callback_manager = callback_manager_cls([]) + if inheritable_callbacks or local_callbacks: + if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: + inheritable_callbacks_ = inheritable_callbacks or [] + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks_, + inheritable_handlers=inheritable_callbacks_, + ) + else: + callback_manager = callback_manager_cls( + handlers=inheritable_callbacks.handlers, + inheritable_handlers=inheritable_callbacks.inheritable_handlers, + parent_run_id=inheritable_callbacks.parent_run_id, + ) + callback_manager = copy.deepcopy(callback_manager) + local_handlers_ = ( + local_callbacks + if isinstance(local_callbacks, list) + else (local_callbacks.handlers if local_callbacks else []) + ) + for handler in local_handlers_: + callback_manager.add_handler(copy.deepcopy(handler), False) + + tracer = tracing_callback_var.get() + open_ai = openai_callback_var.get() + tracing_enabled_ = ( + os.environ.get("LANGCHAIN_TRACING") is not None + or tracer is not None + or os.environ.get("LANGCHAIN_HANDLER") is not None + ) + tracer_session = os.environ.get("LANGCHAIN_SESSION") + if tracer_session is None: + tracer_session = "default" + if verbose or tracing_enabled_ or open_ai is not None: + if verbose and not any( + isinstance(handler, StdOutCallbackHandler) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(StdOutCallbackHandler(), False) + + if tracing_enabled_ and not any( + isinstance(handler, LangChainTracer) + for handler in callback_manager.handlers + ): + if tracer: + callback_manager.add_handler(tracer, True) + else: + handler = LangChainTracer() + handler.load_session(tracer_session) + callback_manager.add_handler(handler, True) + if open_ai is not None and not any( + isinstance(handler, OpenAICallbackHandler) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(open_ai, True) + + return callback_manager diff --git a/langchain/callbacks/openai_info.py b/langchain/callbacks/openai_info.py index 42005acb..3c77f1f2 100644 --- a/langchain/callbacks/openai_info.py +++ b/langchain/callbacks/openai_info.py @@ -148,16 +148,6 @@ class OpenAICallbackHandler(BaseCallbackHandler): """Do 
nothing.""" pass - def on_text( - self, - text: str, - color: Optional[str] = None, - end: str = "", - **kwargs: Optional[str], - ) -> None: - """Run when agent ends.""" - pass - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run on agent action.""" pass @@ -167,3 +157,11 @@ class OpenAICallbackHandler(BaseCallbackHandler): ) -> None: """Run on agent end.""" pass + + def __copy__(self) -> "OpenAICallbackHandler": + """Return a copy of the callback handler.""" + return self + + def __deepcopy__(self, memo: Any) -> "OpenAICallbackHandler": + """Return a deep copy of the callback handler.""" + return self diff --git a/langchain/callbacks/shared.py b/langchain/callbacks/shared.py deleted file mode 100644 index 225b183e..00000000 --- a/langchain/callbacks/shared.py +++ /dev/null @@ -1,127 +0,0 @@ -"""A shared CallbackManager.""" - -import threading -from typing import Any, Dict, List, Union - -from langchain.callbacks.base import ( - BaseCallbackHandler, - BaseCallbackManager, - CallbackManager, -) -from langchain.schema import AgentAction, AgentFinish, LLMResult - - -class Singleton: - """A thread-safe singleton class that can be inherited from.""" - - _instance = None - _lock = threading.Lock() - - def __new__(cls) -> Any: - """Create a new shared instance of the class.""" - if cls._instance is None: - with cls._lock: - # Another thread could have created the instance - # before we acquired the lock. So check that the - # instance is still nonexistent. - if not cls._instance: - cls._instance = super().__new__(cls) - return cls._instance - - -class SharedCallbackManager(Singleton, BaseCallbackManager): - """A thread-safe singleton CallbackManager.""" - - _callback_manager: CallbackManager = CallbackManager(handlers=[]) - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - with self._lock: - self._callback_manager.on_llm_start(serialized, prompts, **kwargs) - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - with self._lock: - self._callback_manager.on_llm_end(response, **kwargs) - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run when LLM generates a new token.""" - with self._lock: - self._callback_manager.on_llm_new_token(token, **kwargs) - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" - with self._lock: - self._callback_manager.on_llm_error(error, **kwargs) - - def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" - with self._lock: - self._callback_manager.on_chain_start(serialized, inputs, **kwargs) - - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - with self._lock: - self._callback_manager.on_chain_end(outputs, **kwargs) - - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when chain errors.""" - with self._lock: - self._callback_manager.on_chain_error(error, **kwargs) - - def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - with self._lock: - self._callback_manager.on_tool_start(serialized, input_str, **kwargs) - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - with 
self._lock: - self._callback_manager.on_agent_action(action, **kwargs) - - def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - with self._lock: - self._callback_manager.on_tool_end(output, **kwargs) - - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when tool errors.""" - with self._lock: - self._callback_manager.on_tool_error(error, **kwargs) - - def on_text(self, text: str, **kwargs: Any) -> None: - """Run on arbitrary text.""" - with self._lock: - self._callback_manager.on_text(text, **kwargs) - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run on agent end.""" - with self._lock: - self._callback_manager.on_agent_finish(finish, **kwargs) - - def add_handler(self, callback: BaseCallbackHandler) -> None: - """Add a callback to the callback manager.""" - with self._lock: - self._callback_manager.add_handler(callback) - - def remove_handler(self, callback: BaseCallbackHandler) -> None: - """Remove a callback from the callback manager.""" - with self._lock: - self._callback_manager.remove_handler(callback) - - def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None: - """Set handlers as the only handlers on the callback manager.""" - with self._lock: - self._callback_manager.handlers = handlers diff --git a/langchain/callbacks/stdout.py b/langchain/callbacks/stdout.py index 18eb0d21..90b0a83e 100644 --- a/langchain/callbacks/stdout.py +++ b/langchain/callbacks/stdout.py @@ -91,7 +91,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): text: str, color: Optional[str] = None, end: str = "", - **kwargs: Optional[str], + **kwargs: Any, ) -> None: """Run when agent ends.""" print_text(text, color=color if color else self.color, end=end) diff --git a/langchain/callbacks/tracers/__init__.py b/langchain/callbacks/tracers/__init__.py index 8db5367f..5dd69b48 100644 --- a/langchain/callbacks/tracers/__init__.py +++ b/langchain/callbacks/tracers/__init__.py @@ -1,12 +1,5 @@ """Tracers that record execution of LangChain runs.""" -from langchain.callbacks.tracers.base import SharedTracer, Tracer -from langchain.callbacks.tracers.langchain import BaseLangChainTracer +from langchain.callbacks.tracers.langchain import LangChainTracer - -class SharedLangChainTracer(SharedTracer, BaseLangChainTracer): - """Shared tracer that records LangChain execution to LangChain endpoint.""" - - -class LangChainTracer(Tracer, BaseLangChainTracer): - """Tracer that records LangChain execution to LangChain endpoint.""" +__all__ = ["LangChainTracer"] diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 2a99c1c8..a7d3b322 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -1,14 +1,12 @@ """Base interfaces for tracing runs.""" from __future__ import annotations -import threading from abc import ABC, abstractmethod -from dataclasses import dataclass, field from datetime import datetime from typing import Any, Dict, List, Optional, Union +from uuid import UUID from langchain.callbacks.base import BaseCallbackHandler -from langchain.callbacks.shared import Singleton from langchain.callbacks.tracers.schemas import ( ChainRun, LLMRun, @@ -16,7 +14,7 @@ from langchain.callbacks.tracers.schemas import ( TracerSession, TracerSessionCreate, ) -from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.schema import LLMResult class TracerException(Exception): @@ -26,13 +24,25 @@ class 
TracerException(Exception): class BaseTracer(BaseCallbackHandler, ABC): """Base interface for tracers.""" - @abstractmethod + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.run_map: Dict[str, Union[LLMRun, ChainRun, ToolRun]] = {} + self.session: Optional[TracerSession] = None + + @staticmethod def _add_child_run( - self, parent_run: Union[ChainRun, ToolRun], child_run: Union[LLMRun, ChainRun, ToolRun], ) -> None: """Add child run to a chain run or tool run.""" + if isinstance(child_run, LLMRun): + parent_run.child_llm_runs.append(child_run) + elif isinstance(child_run, ChainRun): + parent_run.child_chain_runs.append(child_run) + elif isinstance(child_run, ToolRun): + parent_run.child_tool_runs.append(child_run) + else: + raise TracerException(f"Invalid run type: {type(child_run)}") @abstractmethod def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: @@ -42,15 +52,11 @@ class BaseTracer(BaseCallbackHandler, ABC): def _persist_session(self, session: TracerSessionCreate) -> TracerSession: """Persist a tracing session.""" - @abstractmethod - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - def new_session(self, name: Optional[str] = None, **kwargs: Any) -> TracerSession: """NOT thread safe, do not call this method from multiple threads.""" session_create = TracerSessionCreate(name=name, extra=kwargs) session = self._persist_session(session_create) - self._session = session + self.session = session return session @abstractmethod @@ -61,283 +67,248 @@ class BaseTracer(BaseCallbackHandler, ABC): def load_default_session(self) -> TracerSession: """Load the default tracing session and set it as the Tracer's session.""" - @property - @abstractmethod - def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: - """Get the tracer stack.""" - - @property - @abstractmethod - def _execution_order(self) -> int: - """Get the execution order for a run.""" - - @_execution_order.setter - @abstractmethod - def _execution_order(self, value: int) -> None: - """Set the execution order for a run.""" - - @property - @abstractmethod - def _session(self) -> Optional[TracerSession]: - """Get the tracing session.""" - - @_session.setter - @abstractmethod - def _session(self, value: TracerSession) -> None: - """Set the tracing session.""" - def _start_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Start a trace for a run.""" - self._execution_order += 1 - - if self._stack: - if not ( - isinstance(self._stack[-1], ChainRun) - or isinstance(self._stack[-1], ToolRun) - ): + if run.parent_uuid: + parent_run = self.run_map[run.parent_uuid] + if parent_run: + if isinstance(parent_run, LLMRun): + raise TracerException( + "Cannot add child run to an LLM run. " + "LLM runs are not allowed to have children." + ) + self._add_child_run(parent_run, run) + else: raise TracerException( - f"Nested {run.__class__.__name__} can only be" - f" logged inside a ChainRun or ToolRun" + f"Parent run with UUID {run.parent_uuid} not found." 
) - self._add_child_run(self._stack[-1], run) - self._stack.append(run) - def _end_trace(self) -> None: + self.run_map[run.uuid] = run + + def _end_trace(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """End a trace for a run.""" - run = self._stack.pop() - if not self._stack: - self._execution_order = 1 + if not run.parent_uuid: self._persist_run(run) + else: + parent_run = self.run_map.get(run.parent_uuid) + if parent_run is None: + raise TracerException( + f"Parent run with UUID {run.parent_uuid} not found." + ) + if isinstance(parent_run, LLMRun): + raise TracerException("LLM Runs are not allowed to have children. ") + if run.child_execution_order > parent_run.child_execution_order: + parent_run.child_execution_order = run.child_execution_order + self.run_map.pop(run.uuid) + + def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: + """Get the execution order for a run.""" + if parent_run_id is None: + return 1 + + parent_run = self.run_map.get(parent_run_id) + if parent_run is None: + raise TracerException(f"Parent run with UUID {parent_run_id} not found.") + + if isinstance(parent_run, LLMRun): + raise TracerException("LLM Runs are not allowed to have children. ") + + return parent_run.child_execution_order + 1 def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, + serialized: Dict[str, Any], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" - if self._session is None: - raise TracerException( - "Initialize a session with `new_session()` before starting a trace." - ) + if self.session is None: + self.session = self.load_default_session() + run_id_ = str(run_id) + parent_run_id_ = str(parent_run_id) if parent_run_id else None + + execution_order = self._get_execution_order(parent_run_id_) llm_run = LLMRun( + uuid=run_id_, + parent_uuid=parent_run_id_, serialized=serialized, prompts=prompts, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self._execution_order, - session_id=self._session.id, - id=self._generate_id(), + execution_order=execution_order, + child_execution_order=execution_order, + session_id=self.session.id, ) self._start_trace(llm_run) - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Handle a new token for an LLM run.""" - pass - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for an LLM run.""" - if not self._stack or not isinstance(self._stack[-1], LLMRun): + if not run_id: + raise TracerException("No run_id provided for on_llm_end callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + if llm_run is None or not isinstance(llm_run, LLMRun): raise TracerException("No LLMRun found to be traced") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].response = response - - self._end_trace() + llm_run.response = response + llm_run.end_time = datetime.utcnow() + self._end_trace(llm_run) def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + **kwargs: Any, ) -> None: """Handle an error for an LLM run.""" - if not self._stack or not isinstance(self._stack[-1], LLMRun): + if not run_id: + raise TracerException("No run_id provided for on_llm_error callback.") + + run_id_ = str(run_id) + llm_run = self.run_map.get(run_id_) + 
if llm_run is None or not isinstance(llm_run, LLMRun): raise TracerException("No LLMRun found to be traced") - self._stack[-1].error = repr(error) - self._stack[-1].end_time = datetime.utcnow() - - self._end_trace() + llm_run.error = repr(error) + llm_run.end_time = datetime.utcnow() + self._end_trace(llm_run) def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, ) -> None: """Start a trace for a chain run.""" - if self._session is None: - raise TracerException( - "Initialize a session with `new_session()` before starting a trace." - ) + if self.session is None: + self.session = self.load_default_session() + run_id_ = str(run_id) + parent_run_id_ = str(parent_run_id) if parent_run_id else None + + execution_order = self._get_execution_order(parent_run_id_) chain_run = ChainRun( + uuid=run_id_, + parent_uuid=parent_run_id_, serialized=serialized, inputs=inputs, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self._execution_order, + execution_order=execution_order, + child_execution_order=execution_order, child_runs=[], - session_id=self._session.id, - id=self._generate_id(), + session_id=self.session.id, ) self._start_trace(chain_run) - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end( + self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any + ) -> None: """End a trace for a chain run.""" - if not self._stack or not isinstance(self._stack[-1], ChainRun): + run_id_ = str(run_id) + + chain_run = self.run_map.get(run_id_) + if chain_run is None or not isinstance(chain_run, ChainRun): raise TracerException("No ChainRun found to be traced") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].outputs = outputs - - self._end_trace() + chain_run.outputs = outputs + chain_run.end_time = datetime.utcnow() + self._end_trace(chain_run) def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + **kwargs: Any, ) -> None: """Handle an error for a chain run.""" - if not self._stack or not isinstance(self._stack[-1], ChainRun): + run_id_ = str(run_id) + + chain_run = self.run_map.get(run_id_) + if chain_run is None or not isinstance(chain_run, ChainRun): raise TracerException("No ChainRun found to be traced") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].error = repr(error) - - self._end_trace() + chain_run.error = repr(error) + chain_run.end_time = datetime.utcnow() + self._end_trace(chain_run) def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + self, + serialized: Dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, ) -> None: """Start a trace for a tool run.""" - if self._session is None: - raise TracerException( - "Initialize a session with `new_session()` before starting a trace." - ) + if self.session is None: + self.session = self.load_default_session() + run_id_ = str(run_id) + parent_run_id_ = str(parent_run_id) if parent_run_id else None + + execution_order = self._get_execution_order(parent_run_id_) tool_run = ToolRun( + uuid=run_id_, + parent_uuid=parent_run_id_, serialized=serialized, # TODO: this is duplicate info as above, not needed. 
action=str(serialized), tool_input=input_str, extra=kwargs, start_time=datetime.utcnow(), - execution_order=self._execution_order, + execution_order=execution_order, + child_execution_order=execution_order, child_runs=[], - session_id=self._session.id, - id=self._generate_id(), + session_id=self.session.id, ) self._start_trace(tool_run) - def on_tool_end(self, output: str, **kwargs: Any) -> None: + def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None: """End a trace for a tool run.""" - if not self._stack or not isinstance(self._stack[-1], ToolRun): + run_id_ = str(run_id) + + tool_run = self.run_map.get(run_id_) + if tool_run is None or not isinstance(tool_run, ToolRun): raise TracerException("No ToolRun found to be traced") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].output = output - - self._end_trace() + tool_run.output = output + tool_run.end_time = datetime.utcnow() + self._end_trace(tool_run) def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + error: Union[Exception, KeyboardInterrupt], + *, + run_id: UUID, + **kwargs: Any, ) -> None: """Handle an error for a tool run.""" - if not self._stack or not isinstance(self._stack[-1], ToolRun): + run_id_ = str(run_id) + + tool_run = self.run_map.get(run_id_) + if tool_run is None or not isinstance(tool_run, ToolRun): raise TracerException("No ToolRun found to be traced") - self._stack[-1].end_time = datetime.utcnow() - self._stack[-1].error = repr(error) + tool_run.error = repr(error) + tool_run.end_time = datetime.utcnow() + self._end_trace(tool_run) - self._end_trace() + def __deepcopy__(self, memo: dict) -> BaseTracer: + """Deepcopy the tracer.""" + return self - def on_text(self, text: str, **kwargs: Any) -> None: - """Handle a text message.""" - pass - - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Handle an agent finish message.""" - pass - - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Do nothing.""" - pass - - -class Tracer(BaseTracer, ABC): - """A non-thread safe implementation of the BaseTracer interface.""" - - def __init__(self) -> None: - """Initialize a tracer.""" - self._tracer_stack: List[Union[LLMRun, ChainRun, ToolRun]] = [] - self._tracer_execution_order = 1 - self._tracer_session: Optional[TracerSession] = None - - @property - def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: - """Get the tracer stack.""" - return self._tracer_stack - - @property - def _execution_order(self) -> int: - """Get the execution order for a run.""" - return self._tracer_execution_order - - @_execution_order.setter - def _execution_order(self, value: int) -> None: - """Set the execution order for a run.""" - self._tracer_execution_order = value - - @property - def _session(self) -> Optional[TracerSession]: - """Get the tracing session.""" - return self._tracer_session - - @_session.setter - def _session(self, value: TracerSession) -> None: - """Set the tracing session.""" - if self._stack: - raise TracerException( - "Cannot set a session while a trace is being recorded" - ) - self._tracer_session = value - - -@dataclass -class TracerStack(threading.local): - """A stack of runs used for logging.""" - - stack: List[Union[LLMRun, ChainRun, ToolRun]] = field(default_factory=list) - execution_order: int = 1 - - -class SharedTracer(Singleton, BaseTracer, ABC): - """A thread-safe Singleton implementation of BaseTracer.""" - - _tracer_stack = TracerStack() - _tracer_session = None - - @property - 
def _stack(self) -> List[Union[LLMRun, ChainRun, ToolRun]]: - """Get the tracer stack.""" - return self._tracer_stack.stack - - @property - def _execution_order(self) -> int: - """Get the execution order for a run.""" - return self._tracer_stack.execution_order - - @_execution_order.setter - def _execution_order(self, value: int) -> None: - """Set the execution order for a run.""" - self._tracer_stack.execution_order = value - - @property - def _session(self) -> Optional[TracerSession]: - """Get the tracing session.""" - return self._tracer_session - - @_session.setter - def _session(self, value: TracerSession) -> None: - """Set the tracing session.""" - with self._lock: - # TODO: currently, we are only checking current thread's stack. - # Need to make sure that we are not in the middle of a trace - # in any thread. - if self._stack: - raise TracerException( - "Cannot set a session while a trace is being recorded" - ) - self._tracer_session = value + def __copy__(self) -> BaseTracer: + """Copy the tracer.""" + return self diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index d2502204..80e7d2d2 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -3,7 +3,6 @@ from __future__ import annotations import logging import os -from abc import ABC from typing import Any, Dict, Optional, Union import requests @@ -18,14 +17,17 @@ from langchain.callbacks.tracers.schemas import ( ) -class BaseLangChainTracer(BaseTracer, ABC): +class LangChainTracer(BaseTracer): """An implementation of the SharedTracer that POSTS to the langchain endpoint.""" - always_verbose: bool = True - _endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") - _headers: Dict[str, Any] = {"Content-Type": "application/json"} - if os.getenv("LANGCHAIN_API_KEY"): - _headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + def __init__(self, session_name: str = "default", **kwargs: Any) -> None: + """Initialize the LangChain tracer.""" + super().__init__(**kwargs) + self._endpoint: str = os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000") + self._headers: Dict[str, Any] = {"Content-Type": "application/json"} + if os.getenv("LANGCHAIN_API_KEY"): + self._headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY") + self.session = self.load_session(session_name) def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: """Persist a run.""" @@ -59,54 +61,29 @@ class BaseLangChainTracer(BaseTracer, ABC): session = TracerSession(id=1, **session_create.dict()) return session - def load_session(self, session_name: str) -> TracerSession: + def _load_session(self, session_name: Optional[str] = None) -> TracerSession: """Load a session from the tracer.""" try: - r = requests.get( - f"{self._endpoint}/sessions?name={session_name}", - headers=self._headers, - ) + url = f"{self._endpoint}/sessions" + if session_name: + url += f"?name={session_name}" + r = requests.get(url, headers=self._headers) + tracer_session = TracerSession(**r.json()[0]) - self._session = tracer_session - return tracer_session except Exception as e: + session_type = "default" if not session_name else session_name logging.warning( - f"Failed to load session {session_name}, using empty session: {e}" + f"Failed to load {session_type} session, using empty session: {e}" ) tracer_session = TracerSession(id=1) - self._session = tracer_session - return tracer_session + + self.session = tracer_session + return tracer_session + + def load_session(self, 
session_name: str) -> TracerSession: + """Load a session with the given name from the tracer.""" + return self._load_session(session_name) def load_default_session(self) -> TracerSession: """Load the default tracing session and set it as the Tracer's session.""" - try: - r = requests.get( - f"{self._endpoint}/sessions", - headers=self._headers, - ) - # Use the first session result - tracer_session = TracerSession(**r.json()[0]) - self._session = tracer_session - return tracer_session - except Exception as e: - logging.warning(f"Failed to default session, using empty session: {e}") - tracer_session = TracerSession(id=1) - self._session = tracer_session - return tracer_session - - def _add_child_run( - self, - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], - ) -> None: - """Add child run to a chain run or tool run.""" - if isinstance(child_run, LLMRun): - parent_run.child_llm_runs.append(child_run) - elif isinstance(child_run, ChainRun): - parent_run.child_chain_runs.append(child_run) - else: - parent_run.child_tool_runs.append(child_run) - - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return None + return self._load_session("default") diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index bb77d747..ce6368ff 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -2,7 +2,7 @@ from __future__ import annotations import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -32,11 +32,13 @@ class TracerSession(TracerSessionBase): class BaseRun(BaseModel): """Base class for Run.""" - id: Optional[Union[int, str]] = None + uuid: str + parent_uuid: Optional[str] = None start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) extra: Optional[Dict[str, Any]] = None execution_order: int + child_execution_order: int serialized: Dict[str, Any] session_id: int error: Optional[str] = None @@ -57,7 +59,6 @@ class ChainRun(BaseRun): child_llm_runs: List[LLMRun] = Field(default_factory=list) child_chain_runs: List[ChainRun] = Field(default_factory=list) child_tool_runs: List[ToolRun] = Field(default_factory=list) - child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list) class ToolRun(BaseRun): @@ -69,7 +70,6 @@ class ToolRun(BaseRun): child_llm_runs: List[LLMRun] = Field(default_factory=list) child_chain_runs: List[ChainRun] = Field(default_factory=list) child_tool_runs: List[ToolRun] = Field(default_factory=list) - child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = Field(default_factory=list) ChainRun.update_forward_refs() diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 47f37b73..e5af03a0 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -5,12 +5,16 @@ from typing import Any, Dict, List, Optional from pydantic import Field, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.prompts import BasePromptTemplate from langchain.requests import TextRequestsWrapper -from 
langchain.schema import BaseLanguageModel class APIChain(Chain): @@ -61,16 +65,21 @@ class APIChain(Chain): ) return values - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.question_key] api_url = self.api_request_chain.predict( - question=question, api_docs=self.api_docs - ) - self.callback_manager.on_text( - api_url, color="green", end="\n", verbose=self.verbose + question=question, + api_docs=self.api_docs, + callbacks=_run_manager.get_child(), ) + _run_manager.on_text(api_url, color="green", end="\n", verbose=self.verbose) api_response = self.requests_wrapper.get(api_url) - self.callback_manager.on_text( + _run_manager.on_text( api_response, color="yellow", end="\n", verbose=self.verbose ) answer = self.api_answer_chain.predict( @@ -78,19 +87,27 @@ class APIChain(Chain): api_docs=self.api_docs, api_url=api_url, api_response=api_response, + callbacks=_run_manager.get_child(), ) return {self.output_key: answer} - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() question = inputs[self.question_key] api_url = await self.api_request_chain.apredict( - question=question, api_docs=self.api_docs + question=question, + api_docs=self.api_docs, + callbacks=_run_manager.get_child(), ) - self.callback_manager.on_text( + await _run_manager.on_text( api_url, color="green", end="\n", verbose=self.verbose ) api_response = await self.requests_wrapper.aget(api_url) - self.callback_manager.on_text( + await _run_manager.on_text( api_response, color="yellow", end="\n", verbose=self.verbose ) answer = await self.api_answer_chain.apredict( @@ -98,6 +115,7 @@ class APIChain(Chain): api_docs=self.api_docs, api_url=api_url, api_response=api_response, + callbacks=_run_manager.get_child(), ) return {self.output_key: answer} diff --git a/langchain/chains/api/openapi/chain.py b/langchain/chains/api/openapi/chain.py index 0f06276a..8f192271 100644 --- a/langchain/chains/api/openapi/chain.py +++ b/langchain/chains/api/openapi/chain.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, cast from pydantic import BaseModel, Field from requests import Response +from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.chains.api.openapi.requests_chain import APIRequesterChain from langchain.chains.api.openapi.response_chain import APIResponderChain from langchain.chains.base import Chain @@ -106,16 +107,21 @@ class OpenAPIEndpointChain(Chain, BaseModel): else: return {self.output_key: output} - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() intermediate_steps = {} instructions = inputs[self.instructions_key] instructions = instructions[: self.max_text_length] _api_arguments = self.api_request_chain.predict_and_parse( - instructions=instructions + instructions=instructions, callbacks=_run_manager.get_child() ) api_arguments = cast(str, _api_arguments) 
intermediate_steps["request_args"] = api_arguments - self.callback_manager.on_text( + _run_manager.on_text( api_arguments, color="green", end="\n", verbose=self.verbose ) if api_arguments.startswith("ERROR"): @@ -141,18 +147,17 @@ class OpenAPIEndpointChain(Chain, BaseModel): response_text = f"Error with message {str(e)}" response_text = response_text[: self.max_text_length] intermediate_steps["response_text"] = response_text - self.callback_manager.on_text( + _run_manager.on_text( response_text, color="blue", end="\n", verbose=self.verbose ) if self.api_response_chain is not None: _answer = self.api_response_chain.predict_and_parse( response=response_text, instructions=instructions, + callbacks=_run_manager.get_child(), ) answer = cast(str, _answer) - self.callback_manager.on_text( - answer, color="yellow", end="\n", verbose=self.verbose - ) + _run_manager.on_text(answer, color="yellow", end="\n", verbose=self.verbose) return self._get_output(answer, intermediate_steps) else: return self._get_output(response_text, intermediate_steps) @@ -188,6 +193,7 @@ class OpenAPIEndpointChain(Chain, BaseModel): verbose: bool = False, return_intermediate_steps: bool = False, raw_response: bool = False, + callbacks: Callbacks = None, **kwargs: Any # TODO: Handle async ) -> "OpenAPIEndpointChain": @@ -198,12 +204,17 @@ class OpenAPIEndpointChain(Chain, BaseModel): path_params=operation.path_params, ) requests_chain = APIRequesterChain.from_llm_and_typescript( - llm, typescript_definition=operation.to_typescript(), verbose=verbose + llm, + typescript_definition=operation.to_typescript(), + verbose=verbose, + callbacks=callbacks, ) if raw_response: response_chain = None else: - response_chain = APIResponderChain.from_llm(llm, verbose=verbose) + response_chain = APIResponderChain.from_llm( + llm, verbose=verbose, callbacks=callbacks + ) _requests = requests or Requests() return cls( api_request_chain=requests_chain, @@ -213,5 +224,6 @@ class OpenAPIEndpointChain(Chain, BaseModel): param_mapping=param_mapping, verbose=verbose, return_intermediate_steps=return_intermediate_steps, + callbacks=callbacks, **kwargs, ) diff --git a/langchain/chains/api/openapi/requests_chain.py b/langchain/chains/api/openapi/requests_chain.py index acc1e4c3..4bc8bd83 100644 --- a/langchain/chains/api/openapi/requests_chain.py +++ b/langchain/chains/api/openapi/requests_chain.py @@ -2,6 +2,7 @@ import json import re +from typing import Any from langchain.chains.api.openapi.prompts import REQUEST_TEMPLATE from langchain.chains.llm import LLMChain @@ -36,7 +37,11 @@ class APIRequesterChain(LLMChain): @classmethod def from_llm_and_typescript( - cls, llm: BaseLLM, typescript_definition: str, verbose: bool = True + cls, + llm: BaseLLM, + typescript_definition: str, + verbose: bool = True, + **kwargs: Any, ) -> LLMChain: """Get the request parser.""" output_parser = APIRequesterOutputParser() @@ -46,4 +51,4 @@ class APIRequesterChain(LLMChain): partial_variables={"schema": typescript_definition}, input_variables=["instructions"], ) - return cls(prompt=prompt, llm=llm, verbose=verbose) + return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs) diff --git a/langchain/chains/api/openapi/response_chain.py b/langchain/chains/api/openapi/response_chain.py index 8cabbb0a..a1d7c5a1 100644 --- a/langchain/chains/api/openapi/response_chain.py +++ b/langchain/chains/api/openapi/response_chain.py @@ -2,6 +2,7 @@ import json import re +from typing import Any from langchain.chains.api.openapi.prompts import RESPONSE_TEMPLATE from 
langchain.chains.llm import LLMChain @@ -35,7 +36,7 @@ class APIResponderChain(LLMChain): """Get the response parser.""" @classmethod - def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain: + def from_llm(cls, llm: BaseLLM, verbose: bool = True, **kwargs: Any) -> LLMChain: """Get the response parser.""" output_parser = APIResponderOutputParser() prompt = PromptTemplate( @@ -43,4 +44,4 @@ class APIResponderChain(LLMChain): output_parser=output_parser, input_variables=["response", "instructions"], ) - return cls(prompt=prompt, llm=llm, verbose=verbose) + return cls(prompt=prompt, llm=llm, verbose=verbose, **kwargs) diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 1b1837a3..c1e8e9b2 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -1,15 +1,23 @@ """Base interface that all chains should implement.""" +import inspect import json +import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Optional, Union import yaml -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, root_validator, validator import langchain -from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + CallbackManager, + CallbackManagerForChainRun, + Callbacks, +) from langchain.schema import BaseMemory @@ -21,9 +29,8 @@ class Chain(BaseModel, ABC): """Base interface that all chains should implement.""" memory: Optional[BaseMemory] = None - callback_manager: BaseCallbackManager = Field( - default_factory=get_callback_manager, exclude=True - ) + callbacks: Callbacks = None + callback_manager: Optional[BaseCallbackManager] = None verbose: bool = Field( default_factory=_get_verbosity ) # Whether to print the response text @@ -37,15 +44,16 @@ class Chain(BaseModel, ABC): def _chain_type(self) -> str: raise NotImplementedError("Saving not supported for this chain type.") - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. - - This allows users to pass in None as callback manager, which is a nice UX. - """ - return callback_manager or get_callback_manager() + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. 
Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values @validator("verbose", pre=True, always=True) def set_verbose(cls, verbose: Optional[bool]) -> bool: @@ -82,15 +90,26 @@ class Chain(BaseModel, ABC): ) @abstractmethod - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run the logic of this chain and return the output.""" - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run the logic of this chain and return the output.""" raise NotImplementedError("Async call not supported for this chain type.") def __call__( - self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False + self, + inputs: Union[Dict[str, Any], Any], + return_only_outputs: bool = False, + callbacks: Callbacks = None, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -104,21 +123,31 @@ class Chain(BaseModel, ABC): """ inputs = self.prep_inputs(inputs) - self.callback_manager.on_chain_start( + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") + run_manager = callback_manager.on_chain_start( {"name": self.__class__.__name__}, inputs, - verbose=self.verbose, ) try: - outputs = self._call(inputs) + outputs = ( + self._call(inputs, run_manager=run_manager) + if new_arg_supported + else self._call(inputs) + ) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_chain_error(e, verbose=self.verbose) + run_manager.on_chain_error(e) raise e - self.callback_manager.on_chain_end(outputs, verbose=self.verbose) + run_manager.on_chain_end(outputs) return self.prep_outputs(inputs, outputs, return_only_outputs) async def acall( - self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False + self, + inputs: Union[Dict[str, Any], Any], + return_only_outputs: bool = False, + callbacks: Callbacks = None, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. 
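With the `callbacks` field and the new `callbacks` argument on `__call__`/`acall` above, handlers can be attached for the lifetime of a chain or for a single invocation, and `CallbackManager.configure` merges the two. A short usage sketch; the prompt and inputs are illustrative:

    from langchain.callbacks import StdOutCallbackHandler
    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate

    handler = StdOutCallbackHandler()
    prompt = PromptTemplate(
        input_variables=["product"],
        template="Name a company that makes {product}.",
    )

    # Constructor callbacks: apply to every call on this chain.
    chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt, callbacks=[handler])

    # Request callbacks: apply to this invocation only.
    chain({"product": "colorful socks"}, callbacks=[StdOutCallbackHandler()])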
@@ -132,30 +161,24 @@ class Chain(BaseModel, ABC): """ inputs = self.prep_inputs(inputs) - if self.callback_manager.is_async: - await self.callback_manager.on_chain_start( - {"name": self.__class__.__name__}, - inputs, - verbose=self.verbose, - ) - else: - self.callback_manager.on_chain_start( - {"name": self.__class__.__name__}, - inputs, - verbose=self.verbose, - ) + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") + run_manager = await callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + inputs, + ) try: - outputs = await self._acall(inputs) + outputs = ( + await self._acall(inputs, run_manager=run_manager) + if new_arg_supported + else await self._acall(inputs) + ) except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_chain_error(e, verbose=self.verbose) - else: - self.callback_manager.on_chain_error(e, verbose=self.verbose) + await run_manager.on_chain_error(e) raise e - if self.callback_manager.is_async: - await self.callback_manager.on_chain_end(outputs, verbose=self.verbose) - else: - self.callback_manager.on_chain_end(outputs, verbose=self.verbose) + await run_manager.on_chain_end(outputs) return self.prep_outputs(inputs, outputs, return_only_outputs) def prep_outputs( @@ -195,11 +218,13 @@ class Chain(BaseModel, ABC): self._validate_inputs(inputs) return inputs - def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + def apply( + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None + ) -> List[Dict[str, str]]: """Call the chain on all inputs in the list.""" - return [self(inputs) for inputs in input_list] + return [self(inputs, callbacks=callbacks) for inputs in input_list] - def run(self, *args: Any, **kwargs: Any) -> str: + def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: """Run the chain as text in, text out or multiple variables, text out.""" if len(self.output_keys) != 1: raise ValueError( @@ -210,17 +235,17 @@ class Chain(BaseModel, ABC): if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return self(args[0])[self.output_keys[0]] + return self(args[0], callbacks=callbacks)[self.output_keys[0]] if kwargs and not args: - return self(kwargs)[self.output_keys[0]] + return self(kwargs, callbacks=callbacks)[self.output_keys[0]] raise ValueError( f"`run` supported with either positional arguments or keyword arguments" f" but not both. Got args: {args} and kwargs: {kwargs}." 
) - async def arun(self, *args: Any, **kwargs: Any) -> str: + async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: """Run the chain as text in, text out or multiple variables, text out.""" if len(self.output_keys) != 1: raise ValueError( @@ -231,10 +256,10 @@ class Chain(BaseModel, ABC): if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return (await self.acall(args[0]))[self.output_keys[0]] + return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]] if kwargs and not args: - return (await self.acall(kwargs))[self.output_keys[0]] + return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]] raise ValueError( f"`run` supported with either positional arguments or keyword arguments" diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index 731a5528..338ea26a 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -5,6 +5,10 @@ from typing import Any, Dict, List, Optional, Tuple from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.docstore.document import Document from langchain.prompts.base import BasePromptTemplate @@ -68,19 +72,33 @@ class BaseCombineDocumentsChain(Chain, ABC): ) -> Tuple[str, dict]: """Combine documents into a single string asynchronously.""" - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, List[Document]], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() docs = inputs[self.input_key] # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} - output, extra_return_dict = self.combine_docs(docs, **other_keys) + output, extra_return_dict = self.combine_docs( + docs, callbacks=_run_manager.get_child(), **other_keys + ) extra_return_dict[self.output_key] = output return extra_return_dict - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, List[Document]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() docs = inputs[self.input_key] # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} - output, extra_return_dict = await self.acombine_docs(docs, **other_keys) + output, extra_return_dict = await self.acombine_docs( + docs, callbacks=_run_manager.get_child(), **other_keys + ) extra_return_dict[self.output_key] = output return extra_return_dict @@ -108,10 +126,17 @@ class AnalyzeDocumentChain(Chain): """ return self.combine_docs_chain.output_keys - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() document = inputs[self.input_key] docs = self.text_splitter.create_documents([document]) # Other keys are assumed to be needed for LLM prediction - other_keys = {k: v for k, v in inputs.items() if k != self.input_key} + 
other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key} other_keys[self.combine_docs_chain.input_key] = docs - return self.combine_docs_chain(other_keys, return_only_outputs=True) + return self.combine_docs_chain( + other_keys, return_only_outputs=True, callbacks=_run_manager.get_child() + ) diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index b439870d..8b2925de 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -6,6 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple from pydantic import Extra, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document @@ -129,7 +130,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): return self.combine_document_chain def combine_docs( - self, docs: List[Document], token_max: int = 3000, **kwargs: Any + self, + docs: List[Document], + token_max: int = 3000, + callbacks: Callbacks = None, + **kwargs: Any, ) -> Tuple[str, dict]: """Combine documents in a map reduce manner. @@ -138,12 +143,15 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): """ results = self.llm_chain.apply( # FYI - this is parallelized and so it is fast. - [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{self.document_variable_name: d.page_content, **kwargs} for d in docs], + callbacks=callbacks, + ) + return self._process_results( + results, docs, token_max, callbacks=callbacks, **kwargs ) - return self._process_results(results, docs, token_max, **kwargs) async def acombine_docs( - self, docs: List[Document], **kwargs: Any + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Combine documents in a map reduce manner. @@ -152,15 +160,17 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): """ results = await self.llm_chain.aapply( # FYI - this is parallelized and so it is fast. 
- [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], + callbacks=callbacks, ) - return self._process_results(results, docs, **kwargs) + return self._process_results(results, docs, callbacks=callbacks, **kwargs) def _process_results( self, results: List[Dict], docs: List[Document], token_max: int = 3000, + callbacks: Callbacks = None, **kwargs: Any, ) -> Tuple[str, dict]: question_result_key = self.llm_chain.output_key @@ -173,7 +183,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): num_tokens = length_func(result_docs, **kwargs) def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str: - return self._collapse_chain.run(input_documents=docs, **kwargs) + return self._collapse_chain.run( + input_documents=docs, callbacks=callbacks, **kwargs + ) while num_tokens is not None and num_tokens > token_max: new_result_doc_list = _split_list_of_docs( @@ -191,7 +203,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): extra_return_dict = {"intermediate_steps": _results} else: extra_return_dict = {} - output = self.combine_document_chain.run(input_documents=result_docs, **kwargs) + output = self.combine_document_chain.run( + input_documents=result_docs, callbacks=callbacks, **kwargs + ) return output, extra_return_dict @property diff --git a/langchain/chains/combine_documents/map_rerank.py b/langchain/chains/combine_documents/map_rerank.py index 35f198a9..ad8409c3 100644 --- a/langchain/chains/combine_documents/map_rerank.py +++ b/langchain/chains/combine_documents/map_rerank.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from pydantic import Extra, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document @@ -89,19 +90,22 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): ) return values - def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: """Combine documents in a map rerank manner. Combine by mapping first chain over all documents, then reranking the results. """ results = self.llm_chain.apply_and_parse( # FYI - this is parallelized and so it is fast. - [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], + callbacks=callbacks, ) return self._process_results(docs, results) async def acombine_docs( - self, docs: List[Document], **kwargs: Any + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Combine documents in a map rerank manner. @@ -109,7 +113,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): """ results = await self.llm_chain.aapply_and_parse( # FYI - this is parallelized and so it is fast. 
- [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] + [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], + callbacks=callbacks, ) return self._process_results(docs, results) diff --git a/langchain/chains/combine_documents/refine.py b/langchain/chains/combine_documents/refine.py index 7d1ae7ff..4b480090 100644 --- a/langchain/chains/combine_documents/refine.py +++ b/langchain/chains/combine_documents/refine.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Tuple from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, format_document, @@ -85,29 +86,31 @@ class RefineDocumentsChain(BaseCombineDocumentsChain): ) return values - def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: """Combine by mapping first chain over all, then stuffing into final chain.""" inputs = self._construct_initial_inputs(docs, **kwargs) - res = self.initial_llm_chain.predict(**inputs) + res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs) refine_steps = [res] for doc in docs[1:]: base_inputs = self._construct_refine_inputs(doc, res) inputs = {**base_inputs, **kwargs} - res = self.refine_llm_chain.predict(**inputs) + res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs) refine_steps.append(res) return self._construct_result(refine_steps, res) async def acombine_docs( - self, docs: List[Document], **kwargs: Any + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Combine by mapping first chain over all, then stuffing into final chain.""" inputs = self._construct_initial_inputs(docs, **kwargs) - res = await self.initial_llm_chain.apredict(**inputs) + res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs) refine_steps = [res] for doc in docs[1:]: base_inputs = self._construct_refine_inputs(doc, res) inputs = {**base_inputs, **kwargs} - res = await self.refine_llm_chain.apredict(**inputs) + res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs) refine_steps.append(res) return self._construct_result(refine_steps, res) diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index 5e3aa3c5..d39ce632 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import ( BaseCombineDocumentsChain, format_document, @@ -77,19 +78,21 @@ class StuffDocumentsChain(BaseCombineDocumentsChain): prompt = self.llm_chain.prompt.format(**inputs) return self.llm_chain.llm.get_num_tokens(prompt) - def combine_docs(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: - """Stuff all documents into one prompt and pass to LLM.""" - inputs = self._get_inputs(docs, **kwargs) - # Call predict on the LLM. 
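The combine-documents variants in these hunks (stuff, refine, map_reduce, map_rerank) all gain the same `callbacks` parameter on `combine_docs`/`acombine_docs`. A hedged end-to-end sketch of how a handler now reaches the inner LLMChain; the summarize loader and the document text are assumptions, not part of this diff:

    from langchain.callbacks import StdOutCallbackHandler
    from langchain.chains.summarize import load_summarize_chain
    from langchain.docstore.document import Document
    from langchain.llms import OpenAI

    docs = [Document(page_content="Callbacks are now routed through run managers.")]
    chain = load_summarize_chain(OpenAI(temperature=0), chain_type="stuff")

    # The handler is handed to the outer chain; its _call builds a run manager
    # and combine_docs(..., callbacks=run_manager.get_child()) forwards it inward.
    chain.run(input_documents=docs, callbacks=[StdOutCallbackHandler()])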
- return self.llm_chain.predict(**inputs), {} - - async def acombine_docs( - self, docs: List[Document], **kwargs: Any + def combine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any ) -> Tuple[str, dict]: """Stuff all documents into one prompt and pass to LLM.""" inputs = self._get_inputs(docs, **kwargs) # Call predict on the LLM. - return await self.llm_chain.apredict(**inputs), {} + return self.llm_chain.predict(callbacks=callbacks, **inputs), {} + + async def acombine_docs( + self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + ) -> Tuple[str, dict]: + """Stuff all documents into one prompt and pass to LLM.""" + inputs = self._get_inputs(docs, **kwargs) + # Call predict on the LLM. + return await self.llm_chain.apredict(callbacks=callbacks, **inputs), {} @property def _chain_type(self) -> str: diff --git a/langchain/chains/constitutional_ai/base.py b/langchain/chains/constitutional_ai/base.py index 7845da22..007b2092 100644 --- a/langchain/chains/constitutional_ai/base.py +++ b/langchain/chains/constitutional_ai/base.py @@ -1,13 +1,14 @@ """Chain for applying constitutional principles to the outputs of another chain.""" from typing import Any, Dict, List, Optional +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple from langchain.chains.constitutional_ai.principles import PRINCIPLES from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class ConstitutionalChain(Chain): @@ -86,11 +87,16 @@ class ConstitutionalChain(Chain): """Defines the output keys.""" return ["output"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() response = self.chain.run(**inputs) input_prompt = self.chain.prompt.format(**inputs) - self.callback_manager.on_text( + _run_manager.on_text( text="Initial response: " + response + "\n\n", verbose=self.verbose, color="yellow", @@ -103,6 +109,7 @@ class ConstitutionalChain(Chain): input_prompt=input_prompt, output_from_model=response, critique_request=constitutional_principle.critique_request, + callbacks=_run_manager.get_child(), ) critique = self._parse_critique( output_string=raw_critique, @@ -116,22 +123,23 @@ class ConstitutionalChain(Chain): critique_request=constitutional_principle.critique_request, critique=critique, revision_request=constitutional_principle.revision_request, + callbacks=_run_manager.get_child(), ).strip() response = revision - self.callback_manager.on_text( + _run_manager.on_text( text=f"Applying {constitutional_principle.name}..." 
+ "\n\n", verbose=self.verbose, color="green", ) - self.callback_manager.on_text( + _run_manager.on_text( text="Critique: " + critique + "\n\n", verbose=self.verbose, color="blue", ) - self.callback_manager.on_text( + _run_manager.on_text( text="Updated response: " + revision + "\n\n", verbose=self.verbose, color="yellow", diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index 8bd948cf..ce7e7115 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -8,6 +8,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from pydantic import Extra, Field, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain @@ -15,7 +20,7 @@ from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_ from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseMessage, BaseRetriever, Document +from langchain.schema import BaseMessage, BaseRetriever, Document from langchain.vectorstores.base import VectorStore # Depending on the memory type and configuration, the chat history format may differ. @@ -81,14 +86,20 @@ class BaseConversationalRetrievalChain(Chain): def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: """Get docs.""" - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs["question"] get_chat_history = self.get_chat_history or _get_chat_history chat_history_str = get_chat_history(inputs["chat_history"]) if chat_history_str: + callbacks = _run_manager.get_child() new_question = self.question_generator.run( - question=question, chat_history=chat_history_str + question=question, chat_history=chat_history_str, callbacks=callbacks ) else: new_question = question @@ -96,7 +107,9 @@ class BaseConversationalRetrievalChain(Chain): new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str - answer = self.combine_docs_chain.run(input_documents=docs, **new_inputs) + answer = self.combine_docs_chain.run( + input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs + ) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: @@ -106,13 +119,19 @@ class BaseConversationalRetrievalChain(Chain): async def _aget_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: """Get docs.""" - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() question = inputs["question"] get_chat_history = self.get_chat_history or _get_chat_history chat_history_str = get_chat_history(inputs["chat_history"]) if 
chat_history_str: + callbacks = _run_manager.get_child() new_question = await self.question_generator.arun( - question=question, chat_history=chat_history_str + question=question, chat_history=chat_history_str, callbacks=callbacks ) else: new_question = question @@ -120,7 +139,9 @@ class BaseConversationalRetrievalChain(Chain): new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str - answer = await self.combine_docs_chain.arun(input_documents=docs, **new_inputs) + answer = await self.combine_docs_chain.arun( + input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs + ) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: diff --git a/langchain/chains/graph_qa/base.py b/langchain/chains/graph_qa/base.py index addf72f8..112338ae 100644 --- a/langchain/chains/graph_qa/base.py +++ b/langchain/chains/graph_qa/base.py @@ -1,10 +1,11 @@ """Question answering over a graph.""" from __future__ import annotations -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import Field +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT from langchain.chains.llm import LLMChain @@ -51,18 +52,25 @@ class GraphQAChain(Chain): qa_chain = LLMChain(llm=llm, prompt=qa_prompt) entity_chain = LLMChain(llm=llm, prompt=entity_prompt) - return cls(qa_chain=qa_chain, entity_extraction_chain=entity_chain, **kwargs) + return cls( + qa_chain=qa_chain, + entity_extraction_chain=entity_chain, + **kwargs, + ) - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Extract entities, look up info and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] entity_string = self.entity_extraction_chain.run(question) - self.callback_manager.on_text( - "Entities Extracted:", end="\n", verbose=self.verbose - ) - self.callback_manager.on_text( + _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose) + _run_manager.on_text( entity_string, color="green", end="\n", verbose=self.verbose ) entities = get_entities(entity_string) @@ -70,9 +78,10 @@ class GraphQAChain(Chain): for entity in entities: triplets = self.graph.get_entity_knowledge(entity) context += "\n".join(triplets) - self.callback_manager.on_text("Full Context:", end="\n", verbose=self.verbose) - self.callback_manager.on_text( - context, color="green", end="\n", verbose=self.verbose + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose) + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=_run_manager.get_child(), ) - result = self.qa_chain({"question": question, "context": context}) return {self.output_key: result[self.qa_chain.output_key]} diff --git a/langchain/chains/hyde/base.py b/langchain/chains/hyde/base.py index f2f97470..3cd6170e 100644 --- a/langchain/chains/hyde/base.py +++ b/langchain/chains/hyde/base.py @@ -4,11 +4,12 @@ https://arxiv.org/abs/2212.10496 """ from __future__ import annotations -from typing import Dict, List +from typing import Any, Dict, List, Optional import numpy as np from pydantic import Extra +from 
langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.llm import LLMChain @@ -57,18 +58,27 @@ class HypotheticalDocumentEmbedder(Chain, Embeddings): embeddings = self.embed_documents(documents) return self.combine_embeddings(embeddings) - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: """Call the internal llm chain.""" - return self.llm_chain._call(inputs) + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + return self.llm_chain(inputs, callbacks=_run_manager.get_child()) @classmethod def from_llm( - cls, llm: BaseLLM, base_embeddings: Embeddings, prompt_key: str + cls, + llm: BaseLLM, + base_embeddings: Embeddings, + prompt_key: str, + **kwargs: Any, ) -> HypotheticalDocumentEmbedder: """Load and use LLMChain for a specific prompt key.""" prompt = PROMPT_MAP[prompt_key] llm_chain = LLMChain(llm=llm, prompt=prompt) - return cls(base_embeddings=base_embeddings, llm_chain=llm_chain) + return cls(base_embeddings=base_embeddings, llm_chain=llm_chain, **kwargs) @property def _chain_type(self) -> str: diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index eb396322..db29cd93 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -5,11 +5,19 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from pydantic import Extra +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForChainRun, + CallbackManager, + CallbackManagerForChainRun, + Callbacks, +) from langchain.chains.base import Chain from langchain.input import get_colored_text from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import BaseLanguageModel, LLMResult, PromptValue +from langchain.schema import LLMResult, PromptValue class LLMChain(Chain): @@ -53,21 +61,40 @@ class LLMChain(Chain): """ return [self.output_key] - def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: - return self.apply([inputs])[0] + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + response = self.generate([inputs], run_manager=run_manager) + return self.create_outputs(response)[0] - def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + def generate( + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> LLMResult: """Generate LLM result from inputs.""" - prompts, stop = self.prep_prompts(input_list) - return self.llm.generate_prompt(prompts, stop) + prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) + return self.llm.generate_prompt( + prompts, stop, callbacks=run_manager.get_child() if run_manager else None + ) - async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + async def agenerate( + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> LLMResult: """Generate LLM result from inputs.""" prompts, stop = await self.aprep_prompts(input_list) - return await self.llm.agenerate_prompt(prompts, stop) + return await self.llm.agenerate_prompt( + prompts, stop, callbacks=run_manager.get_child() if run_manager else 
None + ) def prep_prompts( - self, input_list: List[Dict[str, Any]] + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Tuple[List[PromptValue], Optional[List[str]]]: """Prepare prompts from inputs.""" stop = None @@ -79,7 +106,8 @@ class LLMChain(Chain): prompt = self.prompt.format_prompt(**selected_inputs) _colored_text = get_colored_text(prompt.to_string(), "green") _text = "Prompt after formatting:\n" + _colored_text - self.callback_manager.on_text(_text, end="\n", verbose=self.verbose) + if run_manager: + run_manager.on_text(_text, end="\n", verbose=self.verbose) if "stop" in inputs and inputs["stop"] != stop: raise ValueError( "If `stop` is present in any inputs, should be present in all." @@ -88,7 +116,9 @@ class LLMChain(Chain): return prompts, stop async def aprep_prompts( - self, input_list: List[Dict[str, Any]] + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, ) -> Tuple[List[PromptValue], Optional[List[str]]]: """Prepare prompts from inputs.""" stop = None @@ -100,12 +130,8 @@ class LLMChain(Chain): prompt = self.prompt.format_prompt(**selected_inputs) _colored_text = get_colored_text(prompt.to_string(), "green") _text = "Prompt after formatting:\n" + _colored_text - if self.callback_manager.is_async: - await self.callback_manager.on_text( - _text, end="\n", verbose=self.verbose - ) - else: - self.callback_manager.on_text(_text, end="\n", verbose=self.verbose) + if run_manager: + await run_manager.on_text(_text, end="\n", verbose=self.verbose) if "stop" in inputs and inputs["stop"] != stop: raise ValueError( "If `stop` is present in any inputs, should be present in all." @@ -113,15 +139,45 @@ class LLMChain(Chain): prompts.append(prompt) return prompts, stop - def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + def apply( + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None + ) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = self.generate(input_list) - return self.create_outputs(response) + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + run_manager = callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + {"input_list": input_list}, + ) + try: + response = self.generate(input_list, run_manager=run_manager) + except (KeyboardInterrupt, Exception) as e: + run_manager.on_chain_error(e) + raise e + outputs = self.create_outputs(response) + run_manager.on_chain_end({"outputs": outputs}) + return outputs - async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + async def aapply( + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None + ) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" - response = await self.agenerate(input_list) - return self.create_outputs(response) + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + run_manager = await callback_manager.on_chain_start( + {"name": self.__class__.__name__}, + {"input_list": input_list}, + ) + try: + response = await self.agenerate(input_list, run_manager=run_manager) + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_chain_error(e) + raise e + outputs = self.create_outputs(response) + await run_manager.on_chain_end({"outputs": outputs}) + return outputs def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: """Create 
outputs from response.""" @@ -131,13 +187,19 @@ class LLMChain(Chain): for generation in response.generations ] - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: - return (await self.aapply([inputs]))[0] + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + response = await self.agenerate([inputs], run_manager=run_manager) + return self.create_outputs(response)[0] - def predict(self, **kwargs: Any) -> str: + def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str: """Format prompt with kwargs and pass to LLM. Args: + callbacks: Callbacks to pass to LLMChain **kwargs: Keys to pass to prompt template. Returns: @@ -148,12 +210,13 @@ class LLMChain(Chain): completion = llm.predict(adjective="funny") """ - return self(kwargs)[self.output_key] + return self(kwargs, callbacks=callbacks)[self.output_key] - async def apredict(self, **kwargs: Any) -> str: + async def apredict(self, callbacks: Callbacks = None, **kwargs: Any) -> str: """Format prompt with kwargs and pass to LLM. Args: + callbacks: Callbacks to pass to LLMChain **kwargs: Keys to pass to prompt template. Returns: @@ -164,31 +227,33 @@ class LLMChain(Chain): completion = llm.predict(adjective="funny") """ - return (await self.acall(kwargs))[self.output_key] + return (await self.acall(kwargs, callbacks=callbacks))[self.output_key] - def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]: + def predict_and_parse( + self, callbacks: Callbacks = None, **kwargs: Any + ) -> Union[str, List[str], Dict[str, str]]: """Call predict and then parse the results.""" - result = self.predict(**kwargs) + result = self.predict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) else: return result async def apredict_and_parse( - self, **kwargs: Any + self, callbacks: Callbacks = None, **kwargs: Any ) -> Union[str, List[str], Dict[str, str]]: """Call apredict and then parse the results.""" - result = await self.apredict(**kwargs) + result = await self.apredict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) else: return result def apply_and_parse( - self, input_list: List[Dict[str, Any]] + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" - result = self.apply(input_list) + result = self.apply(input_list, callbacks=callbacks) return self._parse_result(result) def _parse_result( @@ -202,10 +267,10 @@ class LLMChain(Chain): return result async def aapply_and_parse( - self, input_list: List[Dict[str, Any]] + self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" - result = await self.aapply(input_list) + result = await self.aapply(input_list, callbacks=callbacks) return self._parse_result(result) @property diff --git a/langchain/chains/llm_bash/base.py b/langchain/chains/llm_bash/base.py index c2ae218f..468c0ba7 100644 --- a/langchain/chains/llm_bash/base.py +++ b/langchain/chains/llm_bash/base.py @@ -1,47 +1,24 @@ """Chain that interprets a prompt and executes bash code to perform bash operations.""" +from __future__ import annotations + import logging -import re -from typing import Any, Dict, List +import warnings +from typing import Any, Dict, 
List, Optional -from pydantic import Extra, Field +from pydantic import Extra, Field, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.prompt import PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException +from langchain.schema import OutputParserException from langchain.utilities.bash import BashProcess logger = logging.getLogger(__name__) -class BashOutputParser(BaseOutputParser): - """Parser for bash output.""" - - def parse(self, text: str) -> List[str]: - if "```bash" in text: - return self.get_code_blocks(text) - else: - raise OutputParserException( - f"Failed to parse bash output. Got: {text}", - ) - - @staticmethod - def get_code_blocks(t: str) -> List[str]: - """Get multiple code blocks from the LLM result.""" - code_blocks: List[str] = [] - # Bash markdown code blocks - pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL) - for match in pattern.finditer(t): - matched = match.group(1).strip() - if matched: - code_blocks.extend( - [line for line in matched.split("\n") if line.strip()] - ) - - return code_blocks - - class LLMBashChain(Chain): """Chain that interprets a prompt and executes bash code to perform bash operations. @@ -49,15 +26,16 @@ class LLMBashChain(Chain): .. code-block:: python from langchain import LLMBashChain, OpenAI - llm_bash = LLMBashChain(llm=OpenAI()) + llm_bash = LLMBashChain.from_llm(OpenAI()) """ - llm: BaseLanguageModel - """LLM wrapper to use.""" + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: prompt: BasePromptTemplate = PROMPT - output_parser: BaseOutputParser = Field(default_factory=BashOutputParser) + """[Deprecated]""" bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private: class Config: @@ -66,6 +44,26 @@ class LLMBashChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMBashChain with an llm is deprecated. " + "Please instantiate with llm_chain or using the from_llm class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + prompt = values.get("prompt", PROMPT) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + + @root_validator + def validate_prompt(cls, values: Dict) -> Dict: + if values["llm_chain"].prompt.output_parser is None: + raise ValueError( + "The prompt used by llm_chain is expected to have an output_parser." + ) + return values + @property def input_keys(self) -> List[str]: """Expect input key. 
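A brief sketch of the construction path the deprecation validator above points to (the question string is illustrative): build the chain with `from_llm` so the prompt's output parser is wired in, rather than passing `llm` directly.

    from langchain import LLMBashChain, OpenAI

    llm_bash = LLMBashChain.from_llm(OpenAI(temperature=0))

    # LLMBashChain(llm=OpenAI()) still works, but now emits a DeprecationWarning
    # from the raise_deprecation root validator and builds llm_chain internally.
    llm_bash.run("Please list the files in the current directory.")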
@@ -82,30 +80,34 @@ class LLMBashChain(Chain): """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain(prompt=self.prompt, llm=self.llm) + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key], verbose=self.verbose) - self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose) - - t = llm_executor.predict(question=inputs[self.input_key]) - self.callback_manager.on_text(t, color="green", verbose=self.verbose) + t = self.llm_chain.predict( + question=inputs[self.input_key], callbacks=_run_manager.get_child() + ) + _run_manager.on_text(t, color="green", verbose=self.verbose) t = t.strip() try: - command_list = self.output_parser.parse(t) + parser = self.llm_chain.prompt.output_parser + command_list = parser.parse(t) # type: ignore[union-attr] except OutputParserException as e: - self.callback_manager.on_chain_error(e, verbose=self.verbose) + _run_manager.on_chain_error(e, verbose=self.verbose) raise e if self.verbose: - self.callback_manager.on_text("\nCode: ", verbose=self.verbose) - self.callback_manager.on_text( + _run_manager.on_text("\nCode: ", verbose=self.verbose) + _run_manager.on_text( str(command_list), color="yellow", verbose=self.verbose ) - output = self.bash_process.run(command_list) - - self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - self.callback_manager.on_text(output, color="yellow", verbose=self.verbose) + _run_manager.on_text("\nAnswer: ", verbose=self.verbose) + _run_manager.on_text(output, color="yellow", verbose=self.verbose) return {self.output_key: output} @property @@ -113,11 +115,11 @@ class LLMBashChain(Chain): return "llm_bash_chain" @classmethod - def from_bash_process( + def from_llm( cls, - bash_process: BashProcess, llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, **kwargs: Any, - ) -> "LLMBashChain": - """Create a LLMBashChain from a BashProcess.""" - return cls(llm=llm, bash_process=bash_process, **kwargs) + ) -> LLMBashChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) diff --git a/langchain/chains/llm_bash/prompt.py b/langchain/chains/llm_bash/prompt.py index 27dcbe57..363b5505 100644 --- a/langchain/chains/llm_bash/prompt.py +++ b/langchain/chains/llm_bash/prompt.py @@ -1,5 +1,11 @@ # flake8: noqa +from __future__ import annotations + +import re +from typing import List + from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BaseOutputParser, OutputParserException _PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format: @@ -19,4 +25,36 @@ That is the format. Begin! Question: {question}""" -PROMPT = PromptTemplate(input_variables=["question"], template=_PROMPT_TEMPLATE) + +class BashOutputParser(BaseOutputParser): + """Parser for bash output.""" + + def parse(self, text: str) -> List[str]: + if "```bash" in text: + return self.get_code_blocks(text) + else: + raise OutputParserException( + f"Failed to parse bash output. 
Got: {text}", + ) + + @staticmethod + def get_code_blocks(t: str) -> List[str]: + """Get multiple code blocks from the LLM result.""" + code_blocks: List[str] = [] + # Bash markdown code blocks + pattern = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL) + for match in pattern.finditer(t): + matched = match.group(1).strip() + if matched: + code_blocks.extend( + [line for line in matched.split("\n") if line.strip()] + ) + + return code_blocks + + +PROMPT = PromptTemplate( + input_variables=["question"], + template=_PROMPT_TEMPLATE, + output_parser=BashOutputParser(), +) diff --git a/langchain/chains/llm_checker/base.py b/langchain/chains/llm_checker/base.py index 0702818a..ae2101e0 100644 --- a/langchain/chains/llm_checker/base.py +++ b/langchain/chains/llm_checker/base.py @@ -1,10 +1,12 @@ """Chain for question-answering with self-verification.""" +from __future__ import annotations +import warnings +from typing import Any, Dict, List, Optional -from typing import Dict, List - -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_checker.prompt import ( @@ -18,6 +20,48 @@ from langchain.llms.base import BaseLLM from langchain.prompts import PromptTemplate +def _load_question_to_checked_assertions_chain( + llm: BaseLLM, + create_draft_answer_prompt: PromptTemplate, + list_assertions_prompt: PromptTemplate, + check_assertions_prompt: PromptTemplate, + revised_answer_prompt: PromptTemplate, +) -> SequentialChain: + create_draft_answer_chain = LLMChain( + llm=llm, + prompt=create_draft_answer_prompt, + output_key="statement", + ) + list_assertions_chain = LLMChain( + llm=llm, + prompt=list_assertions_prompt, + output_key="assertions", + ) + check_assertions_chain = LLMChain( + llm=llm, + prompt=check_assertions_prompt, + output_key="checked_assertions", + ) + revised_answer_chain = LLMChain( + llm=llm, + prompt=revised_answer_prompt, + output_key="revised_statement", + ) + chains = [ + create_draft_answer_chain, + list_assertions_chain, + check_assertions_chain, + revised_answer_chain, + ] + question_to_checked_assertions_chain = SequentialChain( + chains=chains, + input_variables=["question"], + output_variables=["revised_statement"], + verbose=True, + ) + return question_to_checked_assertions_chain + + class LLMCheckerChain(Chain): """Chain for question-answering with self-verification. 
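Since the bash parser now ships on the prompt itself (llm_bash/prompt.py above), a quick illustration of what it extracts; the input text is made up:

    from langchain.chains.llm_bash.prompt import BashOutputParser

    parser = BashOutputParser()
    text = "```bash\nls -la\necho done\n```"
    # One list entry per non-empty line inside the ```bash fence.
    print(parser.parse(text))  # -> ['ls -la', 'echo done']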
@@ -26,16 +70,21 @@ class LLMCheckerChain(Chain): from langchain import OpenAI, LLMCheckerChain llm = OpenAI(temperature=0.7) - checker_chain = LLMCheckerChain(llm=llm) + checker_chain = LLMCheckerChain.from_llm(llm) """ - llm: BaseLLM - """LLM wrapper to use.""" + question_to_checked_assertions_chain: SequentialChain + + llm: Optional[BaseLLM] = None + """[Deprecated] LLM wrapper to use.""" create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT + """[Deprecated]""" list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT + """[Deprecated]""" check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT + """[Deprecated]""" revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT - """Prompt to use when questioning the documents.""" + """[Deprecated] Prompt to use when questioning the documents.""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -45,6 +94,34 @@ class LLMCheckerChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMCheckerChain with an llm is deprecated. " + "Please instantiate with question_to_checked_assertions_chain " + "or using the from_llm class method." + ) + if ( + "question_to_checked_assertions_chain" not in values + and values["llm"] is not None + ): + question_to_checked_assertions_chain = ( + _load_question_to_checked_assertions_chain( + values["llm"], + values.get( + "create_draft_answer_prompt", CREATE_DRAFT_ANSWER_PROMPT + ), + values.get("list_assertions_prompt", LIST_ASSERTIONS_PROMPT), + values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT), + values.get("revised_answer_prompt", REVISED_ANSWER_PROMPT), + ) + ) + values[ + "question_to_checked_assertions_chain" + ] = question_to_checked_assertions_chain + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. 
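The checker chain follows the same deprecation pattern; a compact sketch of both construction paths (the temperature value is illustrative):

    from langchain import LLMCheckerChain, OpenAI

    llm = OpenAI(temperature=0.7)

    # Preferred: builds the internal question_to_checked_assertions_chain.
    checker = LLMCheckerChain.from_llm(llm)

    # Deprecated but still accepted: the raise_deprecation validator assembles
    # the same SequentialChain from the llm argument and warns.
    legacy_checker = LLMCheckerChain(llm=llm)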
@@ -61,43 +138,43 @@ class LLMCheckerChain(Chain): """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] - create_draft_answer_chain = LLMChain( - llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement" + output = self.question_to_checked_assertions_chain( + {"question": question}, callbacks=_run_manager.get_child() ) - list_assertions_chain = LLMChain( - llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions" - ) - check_assertions_chain = LLMChain( - llm=self.llm, - prompt=self.check_assertions_prompt, - output_key="checked_assertions", - ) - - revised_answer_chain = LLMChain( - llm=self.llm, - prompt=self.revised_answer_prompt, - output_key="revised_statement", - ) - - chains = [ - create_draft_answer_chain, - list_assertions_chain, - check_assertions_chain, - revised_answer_chain, - ] - - question_to_checked_assertions_chain = SequentialChain( - chains=chains, - input_variables=["question"], - output_variables=["revised_statement"], - verbose=True, - ) - output = question_to_checked_assertions_chain({"question": question}) return {self.output_key: output["revised_statement"]} @property def _chain_type(self) -> str: return "llm_checker_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLLM, + create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT, + list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT, + check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, + revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT, + **kwargs: Any, + ) -> LLMCheckerChain: + question_to_checked_assertions_chain = ( + _load_question_to_checked_assertions_chain( + llm, + create_draft_answer_prompt, + list_assertions_prompt, + check_assertions_prompt, + revised_answer_prompt, + ) + ) + return cls( + question_to_checked_assertions_chain=question_to_checked_assertions_chain, + **kwargs, + ) diff --git a/langchain/chains/llm_math/base.py b/langchain/chains/llm_math/base.py index b1683b7a..1037658d 100644 --- a/langchain/chains/llm_math/base.py +++ b/langchain/chains/llm_math/base.py @@ -1,16 +1,23 @@ """Chain that interprets a prompt and executes python code to do math.""" +from __future__ import annotations + import math import re -from typing import Dict, List +import warnings +from typing import Any, Dict, List, Optional import numexpr -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.llm_math.prompt import PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class LLMMathChain(Chain): @@ -20,13 +27,14 @@ class LLMMathChain(Chain): .. 
code-block:: python from langchain import LLMMathChain, OpenAI - llm_math = LLMMathChain(llm=OpenAI()) + llm_math = LLMMathChain.from_llm(OpenAI()) """ - llm: BaseLanguageModel - """LLM wrapper to use.""" + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" prompt: BasePromptTemplate = PROMPT - """Prompt to use to translate to python if neccessary.""" + """[Deprecated] Prompt to use to translate to python if necessary.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: @@ -36,6 +44,19 @@ class LLMMathChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMMathChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_llm " + "class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + prompt = values.get("prompt", PROMPT) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + @property def input_keys(self) -> List[str]: """Expect input key. @@ -68,15 +89,17 @@ class LLMMathChain(Chain): # Remove any leading and trailing brackets from the output return re.sub(r"^\[|\]$", "", output) - def _process_llm_result(self, llm_output: str) -> Dict[str, str]: - self.callback_manager.on_text(llm_output, color="green", verbose=self.verbose) + def _process_llm_result( + self, llm_output: str, run_manager: CallbackManagerForChainRun + ) -> Dict[str, str]: + run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) if text_match: expression = text_match.group(1) output = self._evaluate_expression(expression) - self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - self.callback_manager.on_text(output, color="yellow", verbose=self.verbose) + run_manager.on_text("\nAnswer: ", verbose=self.verbose) + run_manager.on_text(output, color="yellow", verbose=self.verbose) answer = "Answer: " + output elif llm_output.startswith("Answer:"): answer = llm_output @@ -86,30 +109,19 @@ class LLMMathChain(Chain): raise ValueError(f"unknown format from LLM: {llm_output}") return {self.output_key: answer} - async def _aprocess_llm_result(self, llm_output: str) -> Dict[str, str]: - if self.callback_manager.is_async: - await self.callback_manager.on_text( - llm_output, color="green", verbose=self.verbose - ) - else: - self.callback_manager.on_text( - llm_output, color="green", verbose=self.verbose - ) + async def _aprocess_llm_result( + self, + llm_output: str, + run_manager: AsyncCallbackManagerForChainRun, + ) -> Dict[str, str]: + await run_manager.on_text(llm_output, color="green", verbose=self.verbose) llm_output = llm_output.strip() text_match = re.search(r"^```text(.*?)```", llm_output, re.DOTALL) if text_match: expression = text_match.group(1) output = self._evaluate_expression(expression) - if self.callback_manager.is_async: - await self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - await self.callback_manager.on_text( - output, color="yellow", verbose=self.verbose - ) - else: - self.callback_manager.on_text("\nAnswer: ", verbose=self.verbose) - self.callback_manager.on_text( - output, color="yellow", verbose=self.verbose - ) + await run_manager.on_text("\nAnswer: ", verbose=self.verbose) + await run_manager.on_text(output, color="yellow", 
verbose=self.verbose) answer = "Answer: " + output elif llm_output.startswith("Answer:"): answer = llm_output @@ -119,31 +131,44 @@ class LLMMathChain(Chain): raise ValueError(f"unknown format from LLM: {llm_output}") return {self.output_key: answer} - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain( - prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + _run_manager.on_text(inputs[self.input_key]) + llm_output = self.llm_chain.predict( + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), ) - self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose) - llm_output = llm_executor.predict( - question=inputs[self.input_key], stop=["```output"] - ) - return self._process_llm_result(llm_output) + return self._process_llm_result(llm_output, _run_manager) - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain( - prompt=self.prompt, llm=self.llm, callback_manager=self.callback_manager + async def _acall( + self, + inputs: Dict[str, str], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + await _run_manager.on_text(inputs[self.input_key]) + llm_output = await self.llm_chain.apredict( + question=inputs[self.input_key], + stop=["```output"], + callbacks=_run_manager.get_child(), ) - if self.callback_manager.is_async: - await self.callback_manager.on_text( - inputs[self.input_key], verbose=self.verbose - ) - else: - self.callback_manager.on_text(inputs[self.input_key], verbose=self.verbose) - llm_output = await llm_executor.apredict( - question=inputs[self.input_key], stop=["```output"] - ) - return await self._aprocess_llm_result(llm_output) + return await self._aprocess_llm_result(llm_output, _run_manager) @property def _chain_type(self) -> str: return "llm_math_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + prompt: BasePromptTemplate = PROMPT, + **kwargs: Any, + ) -> LLMMathChain: + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, **kwargs) diff --git a/langchain/chains/llm_requests.py b/langchain/chains/llm_requests.py index 4abab041..d9c05744 100644 --- a/langchain/chains/llm_requests.py +++ b/langchain/chains/llm_requests.py @@ -1,10 +1,11 @@ """Chain that hits a URL and then uses an LLM to parse results.""" from __future__ import annotations -from typing import Dict, List +from typing import Any, Dict, List, Optional from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains import LLMChain from langchain.chains.base import Chain from langchain.requests import TextRequestsWrapper @@ -61,9 +62,14 @@ class LLMRequestsChain(Chain): ) return values - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: from bs4 import BeautifulSoup + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() # Other keys are assumed to be needed for LLM prediction other_keys = {k: v for k, v in inputs.items() if k != self.input_key} url = 
inputs[self.input_key] @@ -71,7 +77,9 @@ class LLMRequestsChain(Chain): # extract the text from the html soup = BeautifulSoup(res, "html.parser") other_keys[self.requests_key] = soup.get_text()[: self.text_length] - result = self.llm_chain.predict(**other_keys) + result = self.llm_chain.predict( + callbacks=_run_manager.get_child(), **other_keys + ) return {self.output_key: result} @property diff --git a/langchain/chains/llm_summarization_checker/base.py b/langchain/chains/llm_summarization_checker/base.py index d69eecb8..e44a5cc7 100644 --- a/langchain/chains/llm_summarization_checker/base.py +++ b/langchain/chains/llm_summarization_checker/base.py @@ -1,10 +1,14 @@ """Chain for summarization with self-verification.""" +from __future__ import annotations + +import warnings from pathlib import Path -from typing import Dict, List +from typing import Any, Dict, List, Optional -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sequential import SequentialChain @@ -27,6 +31,48 @@ ARE_ALL_TRUE_PROMPT = PromptTemplate.from_file( ) +def _load_sequential_chain( + llm: BaseLLM, + create_assertions_prompt: PromptTemplate, + check_assertions_prompt: PromptTemplate, + revised_summary_prompt: PromptTemplate, + are_all_true_prompt: PromptTemplate, + verbose: bool = False, +) -> SequentialChain: + chain = SequentialChain( + chains=[ + LLMChain( + llm=llm, + prompt=create_assertions_prompt, + output_key="assertions", + verbose=verbose, + ), + LLMChain( + llm=llm, + prompt=check_assertions_prompt, + output_key="checked_assertions", + verbose=verbose, + ), + LLMChain( + llm=llm, + prompt=revised_summary_prompt, + output_key="revised_summary", + verbose=verbose, + ), + LLMChain( + llm=llm, + output_key="all_true", + prompt=are_all_true_prompt, + verbose=verbose, + ), + ], + input_variables=["summary"], + output_variables=["all_true", "revised_summary"], + verbose=verbose, + ) + return chain + + class LLMSummarizationCheckerChain(Chain): """Chain for question-answering with self-verification. @@ -35,16 +81,21 @@ class LLMSummarizationCheckerChain(Chain): from langchain import OpenAI, LLMSummarizationCheckerChain llm = OpenAI(temperature=0.0) - checker_chain = LLMSummarizationCheckerChain(llm=llm) + checker_chain = LLMSummarizationCheckerChain.from_llm(llm) """ - llm: BaseLLM - """LLM wrapper to use.""" + sequential_chain: SequentialChain + llm: Optional[BaseLLM] = None + """[Deprecated] LLM wrapper to use.""" create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT + """[Deprecated]""" check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT + """[Deprecated]""" revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT + """[Deprecated]""" are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT + """[Deprecated]""" input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -57,6 +108,25 @@ class LLMSummarizationCheckerChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an LLMSummarizationCheckerChain with an llm is " + "deprecated. Please instantiate with" + " sequential_chain argument or using the from_llm class method." 
+ ) + if "sequential_chain" not in values and values["llm"] is not None: + values["sequential_chain"] = _load_sequential_chain( + values["llm"], + values.get("create_assertions_prompt", CREATE_ASSERTIONS_PROMPT), + values.get("check_assertions_prompt", CHECK_ASSERTIONS_PROMPT), + values.get("revised_summary_prompt", REVISED_SUMMARY_PROMPT), + values.get("are_all_true_prompt", ARE_ALL_TRUE_PROMPT), + verbose=values.get("verbose", False), + ) + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. @@ -73,46 +143,21 @@ class LLMSummarizationCheckerChain(Chain): """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() all_true = False count = 0 output = None original_input = inputs[self.input_key] chain_input = original_input - while not all_true and count < self.max_checks: - chain = SequentialChain( - chains=[ - LLMChain( - llm=self.llm, - prompt=self.create_assertions_prompt, - output_key="assertions", - verbose=self.verbose, - ), - LLMChain( - llm=self.llm, - prompt=self.check_assertions_prompt, - output_key="checked_assertions", - verbose=self.verbose, - ), - LLMChain( - llm=self.llm, - prompt=self.revised_summary_prompt, - output_key="revised_summary", - verbose=self.verbose, - ), - LLMChain( - llm=self.llm, - output_key="all_true", - prompt=self.are_all_true_prompt, - verbose=self.verbose, - ), - ], - input_variables=["summary"], - output_variables=["all_true", "revised_summary"], - verbose=self.verbose, + output = self.sequential_chain( + {"summary": chain_input}, callbacks=_run_manager.get_child() ) - output = chain({"summary": chain_input}) count += 1 if output["all_true"].strip() == "True": @@ -131,3 +176,24 @@ class LLMSummarizationCheckerChain(Chain): @property def _chain_type(self) -> str: return "llm_summarization_checker_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLLM, + create_assertions_prompt: PromptTemplate = CREATE_ASSERTIONS_PROMPT, + check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT, + revised_summary_prompt: PromptTemplate = REVISED_SUMMARY_PROMPT, + are_all_true_prompt: PromptTemplate = ARE_ALL_TRUE_PROMPT, + verbose: bool = False, + **kwargs: Any, + ) -> LLMSummarizationCheckerChain: + chain = _load_sequential_chain( + llm, + create_assertions_prompt, + check_assertions_prompt, + revised_summary_prompt, + are_all_true_prompt, + verbose=verbose, + ) + return cls(sequential_chain=chain, verbose=verbose, **kwargs) diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index 062a9431..f1b66b49 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -5,10 +5,11 @@ then combines the results with another one. 
""" from __future__ import annotations -from typing import Dict, List +from typing import Any, Dict, List, Optional from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -32,16 +33,26 @@ class MapReduceChain(Chain): @classmethod def from_params( - cls, llm: BaseLLM, prompt: BasePromptTemplate, text_splitter: TextSplitter + cls, + llm: BaseLLM, + prompt: BasePromptTemplate, + text_splitter: TextSplitter, + callbacks: Callbacks = None, + **kwargs: Any, ) -> MapReduceChain: """Construct a map-reduce chain that uses the chain for map and reduce.""" - llm_chain = LLMChain(llm=llm, prompt=prompt) - reduce_chain = StuffDocumentsChain(llm_chain=llm_chain) + llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=callbacks) + reduce_chain = StuffDocumentsChain(llm_chain=llm_chain, callbacks=callbacks) combine_documents_chain = MapReduceDocumentsChain( - llm_chain=llm_chain, combine_document_chain=reduce_chain + llm_chain=llm_chain, + combine_document_chain=reduce_chain, + callbacks=callbacks, ) return cls( - combine_documents_chain=combine_documents_chain, text_splitter=text_splitter + combine_documents_chain=combine_documents_chain, + text_splitter=text_splitter, + callbacks=callbacks, + **kwargs, ) class Config: @@ -66,9 +77,16 @@ class MapReduceChain(Chain): """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() # Split the larger text into smaller chunks. 
texts = self.text_splitter.split_text(inputs[self.input_key]) docs = [Document(page_content=text) for text in texts] - outputs = self.combine_documents_chain.run(input_documents=docs) + outputs = self.combine_documents_chain.run( + input_documents=docs, callbacks=_run_manager.get_child() + ) return {self.output_key: outputs} diff --git a/langchain/chains/moderation.py b/langchain/chains/moderation.py index 1e76c436..96528a76 100644 --- a/langchain/chains/moderation.py +++ b/langchain/chains/moderation.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional from pydantic import root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.utils import get_from_dict_or_env @@ -84,7 +85,11 @@ class OpenAIModerationChain(Chain): return error_str return text - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: text = inputs[self.input_key] results = self.client.create(text) output = self._moderate(text, results["results"][0]) diff --git a/langchain/chains/natbot/base.py b/langchain/chains/natbot/base.py index 369f0f45..452f7860 100644 --- a/langchain/chains/natbot/base.py +++ b/langchain/chains/natbot/base.py @@ -1,10 +1,12 @@ """Implement an LLM driven browser.""" from __future__ import annotations -from typing import Dict, List +import warnings +from typing import Any, Dict, List, Optional -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.natbot.prompt import PROMPT @@ -18,14 +20,15 @@ class NatBotChain(Chain): Example: .. code-block:: python - from langchain import NatBotChain, OpenAI - natbot = NatBotChain(llm=OpenAI(), objective="Buy me a new hat.") + from langchain import NatBotChain + natbot = NatBotChain.from_default("Buy me a new hat.") """ - llm: BaseLLM - """LLM wrapper to use.""" + llm_chain: LLMChain objective: str """Objective that NatBot is tasked with completing.""" + llm: Optional[BaseLLM] = None + """[Deprecated] LLM wrapper to use.""" input_url_key: str = "url" #: :meta private: input_browser_content_key: str = "browser_content" #: :meta private: previous_command: str = "" #: :meta private: @@ -37,11 +40,29 @@ class NatBotChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an NatBotChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_llm " + "class method." 
+ ) + if "llm_chain" not in values and values["llm"] is not None: + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=PROMPT) + return values + @classmethod - def from_default(cls, objective: str) -> NatBotChain: - """Load with default LLM.""" + def from_default(cls, objective: str, **kwargs: Any) -> NatBotChain: + """Load with default LLMChain.""" llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50) - return cls(llm=llm, objective=objective) + return cls.from_llm(llm, objective, **kwargs) + + @classmethod + def from_llm(cls, llm: BaseLLM, objective: str, **kwargs: Any) -> NatBotChain: + """Load from LLM.""" + llm_chain = LLMChain(llm=llm, prompt=PROMPT) + return cls(llm_chain=llm_chain, objective=objective, **kwargs) @property def input_keys(self) -> List[str]: @@ -59,15 +80,20 @@ class NatBotChain(Chain): """ return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_executor = LLMChain(prompt=PROMPT, llm=self.llm) + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() url = inputs[self.input_url_key] browser_content = inputs[self.input_browser_content_key] - llm_cmd = llm_executor.predict( + llm_cmd = self.llm_chain.predict( objective=self.objective, url=url[:100], previous_command=self.previous_command, browser_content=browser_content[:4500], + callbacks=_run_manager.get_child(), ) llm_cmd = llm_cmd.strip() self.previous_command = llm_cmd diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 0d15b90b..275680a8 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -4,24 +4,29 @@ As in https://arxiv.org/pdf/2211.10435.pdf. """ from __future__ import annotations +import warnings from typing import Any, Dict, List, Optional -from pydantic import Extra +from pydantic import Extra, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT from langchain.chains.pal.math_prompt import MATH_PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel from langchain.utilities import PythonREPL class PALChain(Chain): """Implements Program-Aided Language Models.""" - llm: BaseLanguageModel - prompt: BasePromptTemplate + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated]""" + prompt: BasePromptTemplate = MATH_PROMPT + """[Deprecated]""" stop: str = "\n\n" get_answer_expr: str = "print(solution())" python_globals: Optional[Dict[str, Any]] = None @@ -35,6 +40,19 @@ class PALChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an PALChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the one of " + "the class method constructors from_math_prompt, " + "from_colored_object_prompt." + ) + if "llm_chain" not in values and values["llm"] is not None: + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=MATH_PROMPT) + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. 
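# --- Illustrative sketch (not part of the patch) ---
# The hunks above move NatBotChain and PALChain from a bare `llm` field to a wrapped
# `llm_chain`, with a pre root_validator (`raise_deprecation`) that warns when the old
# `llm` keyword is still used. A minimal usage sketch, assuming the constructors shown
# in this diff (NatBotChain.from_llm, PALChain.from_math_prompt) and an OpenAI API key
# in the environment:
from langchain.chains.natbot.base import NatBotChain
from langchain.chains.pal.base import PALChain
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)

# Preferred: the new class methods build the inner LLMChain for you.
natbot = NatBotChain.from_llm(llm, objective="Buy me a new hat.")
pal = PALChain.from_math_prompt(llm)

# Still accepted for now, but triggers the deprecation warning added above;
# the validator synthesizes llm_chain from llm and the default prompt.
legacy_pal = PALChain(llm=llm)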
@@ -54,12 +72,16 @@ class PALChain(Chain): else: return [self.output_key, "intermediate_steps"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: - llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) - code = llm_chain.predict(stop=[self.stop], **inputs) - self.callback_manager.on_text( - code, color="green", end="\n", verbose=self.verbose + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + code = self.llm_chain.predict( + stop=[self.stop], callbacks=_run_manager.get_child(), **inputs ) + _run_manager.on_text(code, color="green", end="\n", verbose=self.verbose) repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals) res = repl.run(code + f"\n{self.get_answer_expr}") output = {self.output_key: res.strip()} @@ -70,9 +92,9 @@ class PALChain(Chain): @classmethod def from_math_prompt(cls, llm: BaseLanguageModel, **kwargs: Any) -> PALChain: """Load PAL from math prompt.""" + llm_chain = LLMChain(llm=llm, prompt=MATH_PROMPT) return cls( - llm=llm, - prompt=MATH_PROMPT, + llm_chain=llm_chain, stop="\n\n", get_answer_expr="print(solution())", **kwargs, @@ -83,9 +105,9 @@ class PALChain(Chain): cls, llm: BaseLanguageModel, **kwargs: Any ) -> PALChain: """Load PAL from colored object prompt.""" + llm_chain = LLMChain(llm=llm, prompt=COLORED_OBJECT_PROMPT) return cls( - llm=llm, - prompt=COLORED_OBJECT_PROMPT, + llm_chain=llm_chain, stop="\n\n\n", get_answer_expr="print(answer)", **kwargs, diff --git a/langchain/chains/prompt_selector.py b/langchain/chains/prompt_selector.py index 190907cc..e40e4f8a 100644 --- a/langchain/chains/prompt_selector.py +++ b/langchain/chains/prompt_selector.py @@ -3,10 +3,10 @@ from typing import Callable, List, Tuple from pydantic import BaseModel, Field +from langchain.base_language import BaseLanguageModel from langchain.chat_models.base import BaseChatModel from langchain.llms.base import BaseLLM from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class BasePromptSelector(BaseModel, ABC): diff --git a/langchain/chains/qa_generation/base.py b/langchain/chains/qa_generation/base.py index 66907bef..1c0ae6b9 100644 --- a/langchain/chains/qa_generation/base.py +++ b/langchain/chains/qa_generation/base.py @@ -5,11 +5,12 @@ from typing import Any, Dict, List, Optional from pydantic import Field +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter @@ -45,11 +46,14 @@ class QAGenerationChain(Chain): def output_keys(self) -> List[str]: return [self.output_key] - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, List]: docs = self.text_splitter.create_documents([inputs[self.input_key]]) - results = self.llm_chain.generate([{"text": d.page_content} for d in docs]) + results = self.llm_chain.generate( + [{"text": d.page_content} for d in docs], run_manager=run_manager + ) qa = [json.loads(res[0].text) for res in 
results.generations] return {self.output_key: qa} - - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: - raise NotImplementedError diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index 5c6317ed..96048a0a 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -8,6 +8,11 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -21,7 +26,6 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import ( ) from langchain.docstore.document import Document from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class BaseQAWithSourcesChain(Chain, ABC): @@ -114,9 +118,16 @@ class BaseQAWithSourcesChain(Chain, ABC): def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]: """Get docs to run questioning over.""" - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() docs = self._get_docs(inputs) - answer = self.combine_documents_chain.run(input_documents=docs, **inputs) + answer = self.combine_documents_chain.run( + input_documents=docs, callbacks=_run_manager.get_child(), **inputs + ) if re.search(r"SOURCES:\s", answer): answer, sources = re.split(r"SOURCES:\s", answer) else: @@ -133,9 +144,16 @@ class BaseQAWithSourcesChain(Chain, ABC): async def _aget_docs(self, inputs: Dict[str, Any]) -> List[Document]: """Get docs to run questioning over.""" - async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() docs = await self._aget_docs(inputs) - answer = await self.combine_documents_chain.arun(input_documents=docs, **inputs) + answer = await self.combine_documents_chain.arun( + input_documents=docs, callbacks=_run_manager.get_child(), **inputs + ) if re.search(r"SOURCES:\s", answer): answer, sources = re.split(r"SOURCES:\s", answer) else: diff --git a/langchain/chains/qa_with_sources/loading.py b/langchain/chains/qa_with_sources/loading.py index c1d923ae..2ce4c56e 100644 --- a/langchain/chains/qa_with_sources/loading.py +++ b/langchain/chains/qa_with_sources/loading.py @@ -1,6 +1,7 @@ """Load question answering with sources chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.base_language import BaseLanguageModel from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain @@ -14,7 +15,6 @@ from langchain.chains.qa_with_sources import ( ) from langchain.chains.question_answering import map_rerank_prompt from langchain.prompts.base import BasePromptTemplate -from langchain.schema import 
BaseLanguageModel class LoadingCallable(Protocol): diff --git a/langchain/chains/query_constructor/base.py b/langchain/chains/query_constructor/base.py index 3fb80c43..dd5062a9 100644 --- a/langchain/chains/query_constructor/base.py +++ b/langchain/chains/query_constructor/base.py @@ -5,6 +5,7 @@ import json from typing import Any, Callable, List, Optional, Sequence from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain +from langchain.base_language import BaseLanguageModel from langchain.chains.query_constructor.ir import ( Comparator, Operator, @@ -20,7 +21,7 @@ from langchain.chains.query_constructor.prompt import ( ) from langchain.chains.query_constructor.schema import AttributeInfo from langchain.output_parsers.structured import parse_json_markdown -from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException +from langchain.schema import BaseOutputParser, OutputParserException class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 2ba684f0..95c24f0a 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -1,6 +1,7 @@ """Load question answering chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain @@ -15,7 +16,6 @@ from langchain.chains.question_answering import ( stuff_prompt, ) from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class LoadingCallable(Protocol): diff --git a/langchain/chains/retrieval_qa/base.py b/langchain/chains/retrieval_qa/base.py index dc1d68bf..2255f957 100644 --- a/langchain/chains/retrieval_qa/base.py +++ b/langchain/chains/retrieval_qa/base.py @@ -7,6 +7,11 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, Field, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain @@ -14,7 +19,7 @@ from langchain.chains.llm import LLMChain from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel, BaseRetriever, Document +from langchain.schema import BaseRetriever, Document from langchain.vectorstores.base import VectorStore @@ -92,7 +97,11 @@ class BaseRetrievalQA(Chain): def _get_docs(self, question: str) -> List[Document]: """Get documents to do question answering over.""" - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run get_relevant_text and llm on input query. 
If chain has 'return_source_documents' as 'True', returns @@ -104,11 +113,12 @@ class BaseRetrievalQA(Chain): res = indexqa({'query': 'This is my query'}) answer, docs = res['result'], res['source_documents'] """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] docs = self._get_docs(question) answer = self.combine_documents_chain.run( - input_documents=docs, question=question + input_documents=docs, question=question, callbacks=_run_manager.get_child() ) if self.return_source_documents: @@ -120,7 +130,11 @@ class BaseRetrievalQA(Chain): async def _aget_docs(self, question: str) -> List[Document]: """Get documents to do question answering over.""" - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, Any]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run get_relevant_text and llm on input query. If chain has 'return_source_documents' as 'True', returns @@ -132,11 +146,12 @@ class BaseRetrievalQA(Chain): res = indexqa({'query': 'This is my query'}) answer, docs = res['result'], res['source_documents'] """ + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] docs = await self._aget_docs(question) answer = await self.combine_documents_chain.arun( - input_documents=docs, question=question + input_documents=docs, question=question, callbacks=_run_manager.get_child() ) if self.return_source_documents: diff --git a/langchain/chains/sequential.py b/langchain/chains/sequential.py index b21dfac5..f94b5bc5 100644 --- a/langchain/chains/sequential.py +++ b/langchain/chains/sequential.py @@ -1,8 +1,12 @@ """Chain pipeline where the outputs of one step feed directly into next.""" -from typing import Dict, List +from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain from langchain.input import get_color_mapping @@ -86,17 +90,31 @@ class SequentialChain(Chain): return values - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: known_values = inputs.copy() + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() for i, chain in enumerate(self.chains): - outputs = chain(known_values, return_only_outputs=True) + callbacks = _run_manager.get_child() + outputs = chain(known_values, return_only_outputs=True, callbacks=callbacks) known_values.update(outputs) return {k: known_values[k] for k in self.output_variables} - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: known_values = inputs.copy() + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() for i, chain in enumerate(self.chains): - outputs = await chain.acall(known_values, return_only_outputs=True) + outputs = await chain.acall( + known_values, return_only_outputs=True, callbacks=callbacks + ) known_values.update(outputs) return {k: known_values[k] for k in self.output_variables} @@ -147,31 +165,37 @@ class SimpleSequentialChain(Chain): ) return values - def 
_call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _input = inputs[self.input_key] color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): - _input = chain.run(_input) + _input = chain.run(_input, callbacks=_run_manager.get_child()) if self.strip_outputs: _input = _input.strip() - self.callback_manager.on_text( + _run_manager.on_text( _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose ) return {self.output_key: _input} - async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() _input = inputs[self.input_key] color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) for i, chain in enumerate(self.chains): - _input = await chain.arun(_input) + _input = await chain.arun(_input, callbacks=callbacks) if self.strip_outputs: _input = _input.strip() - if self.callback_manager.is_async: - await self.callback_manager.on_text( - _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose - ) - else: - self.callback_manager.on_text( - _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose - ) + await _run_manager.on_text( + _input, color=color_mapping[str(i)], end="\n", verbose=self.verbose + ) return {self.output_key: _input} diff --git a/langchain/chains/sql_database/base.py b/langchain/chains/sql_database/base.py index aa1a2e6d..d73d34a3 100644 --- a/langchain/chains/sql_database/base.py +++ b/langchain/chains/sql_database/base.py @@ -1,15 +1,17 @@ """Chain for interacting with SQL Database.""" from __future__ import annotations +import warnings from typing import Any, Dict, List, Optional -from pydantic import Extra, Field +from pydantic import Extra, Field, root_validator +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel from langchain.sql_database import SQLDatabase @@ -21,15 +23,16 @@ class SQLDatabaseChain(Chain): from langchain import SQLDatabaseChain, OpenAI, SQLDatabase db = SQLDatabase(...) 
- db_chain = SQLDatabaseChain(llm=OpenAI(), database=db) + db_chain = SQLDatabaseChain.from_llm(OpenAI(), db) """ - llm: BaseLanguageModel - """LLM wrapper to use.""" + llm_chain: LLMChain + llm: Optional[BaseLanguageModel] = None + """[Deprecated] LLM wrapper to use.""" database: SQLDatabase = Field(exclude=True) """SQL Database to connect to.""" prompt: Optional[BasePromptTemplate] = None - """Prompt to use to translate natural language to SQL.""" + """[Deprecated] Prompt to use to translate natural language to SQL.""" top_k: int = 5 """Number of results to return from the query""" input_key: str = "query" #: :meta private: @@ -45,6 +48,22 @@ class SQLDatabaseChain(Chain): extra = Extra.forbid arbitrary_types_allowed = True + @root_validator(pre=True) + def raise_deprecation(cls, values: Dict) -> Dict: + if "llm" in values: + warnings.warn( + "Directly instantiating an SQLDatabaseChain with an llm is deprecated. " + "Please instantiate with llm_chain argument or using the from_llm " + "class method." + ) + if "llm_chain" not in values and values["llm"] is not None: + database = values["database"] + prompt = values.get("prompt") or SQL_PROMPTS.get( + database.dialect, PROMPT + ) + values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) + return values + @property def input_keys(self) -> List[str]: """Return the singular input key. @@ -64,11 +83,14 @@ class SQLDatabaseChain(Chain): else: return [self.output_key, "intermediate_steps"] - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - prompt = self.prompt or SQL_PROMPTS.get(self.database.dialect, PROMPT) - llm_chain = LLMChain(llm=self.llm, prompt=prompt) + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() input_text = f"{inputs[self.input_key]}\nSQLQuery:" - self.callback_manager.on_text(input_text, verbose=self.verbose) + _run_manager.on_text(input_text, verbose=self.verbose) # If not present, then defaults to None which is all tables. 
table_names_to_use = inputs.get("table_names_to_use") table_info = self.database.get_table_info(table_names=table_names_to_use) @@ -80,24 +102,26 @@ class SQLDatabaseChain(Chain): "stop": ["\nSQLResult:"], } intermediate_steps = [] - sql_cmd = llm_chain.predict(**llm_inputs) + sql_cmd = self.llm_chain.predict( + callbacks=_run_manager.get_child(), **llm_inputs + ) intermediate_steps.append(sql_cmd) - self.callback_manager.on_text(sql_cmd, color="green", verbose=self.verbose) + _run_manager.on_text(sql_cmd, color="green", verbose=self.verbose) result = self.database.run(sql_cmd) intermediate_steps.append(result) - self.callback_manager.on_text("\nSQLResult: ", verbose=self.verbose) - self.callback_manager.on_text(result, color="yellow", verbose=self.verbose) + _run_manager.on_text("\nSQLResult: ", verbose=self.verbose) + _run_manager.on_text(result, color="yellow", verbose=self.verbose) # If return direct, we just set the final result equal to the sql query if self.return_direct: final_result = result else: - self.callback_manager.on_text("\nAnswer:", verbose=self.verbose) + _run_manager.on_text("\nAnswer:", verbose=self.verbose) input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:" llm_inputs["input"] = input_text - final_result = llm_chain.predict(**llm_inputs) - self.callback_manager.on_text( - final_result, color="green", verbose=self.verbose + final_result = self.llm_chain.predict( + callbacks=_run_manager.get_child(), **llm_inputs ) + _run_manager.on_text(final_result, color="green", verbose=self.verbose) chain_result: Dict[str, Any] = {self.output_key: final_result} if self.return_intermediate_steps: chain_result["intermediate_steps"] = intermediate_steps @@ -107,6 +131,18 @@ class SQLDatabaseChain(Chain): def _chain_type(self) -> str: return "sql_database_chain" + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + db: SQLDatabase, + prompt: Optional[BasePromptTemplate] = None, + **kwargs: Any, + ) -> SQLDatabaseChain: + prompt = prompt or SQL_PROMPTS.get(db.dialect, PROMPT) + llm_chain = LLMChain(llm=llm, prompt=prompt) + return cls(llm_chain=llm_chain, database=db, **kwargs) + class SQLDatabaseSequentialChain(Chain): """Chain for querying SQL database that is a sequential chain. @@ -118,6 +154,10 @@ class SQLDatabaseSequentialChain(Chain): This is useful in cases where the number of tables in the database is large. """ + decider_chain: LLMChain + sql_chain: SQLDatabaseChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: return_intermediate_steps: bool = False @classmethod @@ -138,11 +178,6 @@ class SQLDatabaseSequentialChain(Chain): ) return cls(sql_chain=sql_chain, decider_chain=decider_chain, **kwargs) - decider_chain: LLMChain - sql_chain: SQLDatabaseChain - input_key: str = "query" #: :meta private: - output_key: str = "result" #: :meta private: - @property def input_keys(self) -> List[str]: """Return the singular input key. 
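# --- Illustrative sketch (not part of the patch) ---
# SQLDatabaseChain now holds an `llm_chain` and forwards callbacks to child runs via
# `_run_manager.get_child()`; the new `from_llm` class method selects the dialect
# specific prompt from SQL_PROMPTS. A minimal usage sketch, assuming a local SQLite
# file `Chinook.db` (hypothetical path) and an OpenAI API key in the environment:
from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
from langchain.callbacks import StdOutCallbackHandler

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db)

# Per-call handlers flow down to the inner LLMChain through the run manager's children,
# matching the callbacks-list style used throughout this refactor.
db_chain.run("How many employees are there?", callbacks=[StdOutCallbackHandler()])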
@@ -162,25 +197,32 @@ class SQLDatabaseSequentialChain(Chain): else: return [self.output_key, "intermediate_steps"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _table_names = self.sql_chain.database.get_usable_table_names() table_names = ", ".join(_table_names) llm_inputs = { "query": inputs[self.input_key], "table_names": table_names, } - table_names_to_use = self.decider_chain.predict_and_parse(**llm_inputs) - self.callback_manager.on_text( - "Table names to use:", end="\n", verbose=self.verbose + table_names_to_use = self.decider_chain.predict_and_parse( + callbacks=_run_manager.get_child(), **llm_inputs ) - self.callback_manager.on_text( + _run_manager.on_text("Table names to use:", end="\n", verbose=self.verbose) + _run_manager.on_text( str(table_names_to_use), color="yellow", verbose=self.verbose ) new_inputs = { self.sql_chain.input_key: inputs[self.input_key], "table_names_to_use": table_names_to_use, } - return self.sql_chain(new_inputs, return_only_outputs=True) + return self.sql_chain( + new_inputs, callbacks=_run_manager.get_child(), return_only_outputs=True + ) @property def _chain_type(self) -> str: diff --git a/langchain/chains/summarize/__init__.py b/langchain/chains/summarize/__init__.py index c31fda47..6fc835dd 100644 --- a/langchain/chains/summarize/__init__.py +++ b/langchain/chains/summarize/__init__.py @@ -1,6 +1,7 @@ """Load summarizing chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.base_language import BaseLanguageModel from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.refine import RefineDocumentsChain @@ -8,7 +9,6 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.llm import LLMChain from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel class LoadingCallable(Protocol): diff --git a/langchain/chains/transform.py b/langchain/chains/transform.py index eb5cb314..90947b2b 100644 --- a/langchain/chains/transform.py +++ b/langchain/chains/transform.py @@ -1,6 +1,7 @@ """Chain that runs an arbitrary python function.""" -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain @@ -35,5 +36,9 @@ class TransformChain(Chain): """ return self.output_variables - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: return self.transform(inputs) diff --git a/langchain/chat_models/anthropic.py b/langchain/chat_models/anthropic.py index f56c6063..daed935b 100644 --- a/langchain/chat_models/anthropic.py +++ b/langchain/chat_models/anthropic.py @@ -2,6 +2,10 @@ from typing import Any, Dict, List, Optional from pydantic import Extra +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chat_models.base import BaseChatModel from langchain.llms.anthropic import 
_AnthropicCommon from langchain.schema import ( @@ -85,7 +89,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): ) # trim off the trailing ' ' that might come from the "Assistant: " def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: prompt = self._convert_messages_to_prompt(messages) params: Dict[str, Any] = {"prompt": prompt, **self._default_params} @@ -98,10 +105,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): for data in stream_resp: delta = data["completion"][len(completion) :] completion = data["completion"] - self.callback_manager.on_llm_new_token( - delta, - verbose=self.verbose, - ) + if run_manager: + run_manager.on_llm_new_token( + delta, + ) else: response = self.client.completion(**params) completion = response["completion"] @@ -109,7 +116,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): return ChatResult(generations=[ChatGeneration(message=message)]) async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: prompt = self._convert_messages_to_prompt(messages) params: Dict[str, Any] = {"prompt": prompt, **self._default_params} @@ -122,15 +132,9 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): async for data in stream_resp: delta = data["completion"][len(completion) :] completion = data["completion"] - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( + if run_manager: + await run_manager.on_llm_new_token( delta, - verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_new_token( - delta, - verbose=self.verbose, ) else: response = await self.client.acompletion(**params) diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 91816de1..fbcc08b1 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -1,21 +1,30 @@ import asyncio +import inspect +import warnings from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Dict, List, Optional -from pydantic import Extra, Field, validator +from pydantic import Extra, Field, root_validator import langchain -from langchain.callbacks import get_callback_manager +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, + Callbacks, +) from langchain.schema import ( AIMessage, - BaseLanguageModel, BaseMessage, ChatGeneration, ChatResult, HumanMessage, LLMResult, PromptValue, + get_buffer_string, ) @@ -26,7 +35,19 @@ def _get_verbosity() -> bool: class BaseChatModel(BaseLanguageModel, ABC): verbose: bool = Field(default_factory=_get_verbosity) """Whether to print out response text.""" - callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + callbacks: Callbacks = None + callback_manager: Optional[BaseCallbackManager] = None + + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. 
Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values class Config: """Configuration for this pydantic object.""" @@ -34,98 +55,130 @@ class BaseChatModel(BaseLanguageModel, ABC): extra = Extra.forbid arbitrary_types_allowed = True - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. - - This allows users to pass in None as callback manager, which is a nice UX. - """ - return callback_manager or get_callback_manager() - def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: return {} def generate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Top Level call""" - results = [self._generate(m, stop=stop) for m in messages] + + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + message_strings = [get_buffer_string(m) for m in messages] + run_manager = callback_manager.on_llm_start( + {"name": self.__class__.__name__}, message_strings + ) + + new_arg_supported = inspect.signature(self._generate).parameters.get( + "run_manager" + ) + try: + results = [ + self._generate(m, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._generate(m, stop=stop) + for m in messages + ] + except (KeyboardInterrupt, Exception) as e: + run_manager.on_llm_error(e) + raise e llm_output = self._combine_llm_outputs([res.llm_output for res in results]) generations = [res.generations for res in results] - return LLMResult(generations=generations, llm_output=llm_output) + output = LLMResult(generations=generations, llm_output=llm_output) + run_manager.on_llm_end(output) + return output async def agenerate( - self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + self, + messages: List[List[BaseMessage]], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Top Level call""" - results = await asyncio.gather( - *[self._agenerate(m, stop=stop) for m in messages] + + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose ) + message_strings = [get_buffer_string(m) for m in messages] + run_manager = await callback_manager.on_llm_start( + {"name": self.__class__.__name__}, message_strings + ) + + new_arg_supported = inspect.signature(self._agenerate).parameters.get( + "run_manager" + ) + try: + results = await asyncio.gather( + *[ + self._agenerate(m, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._agenerate(m, stop=stop) + for m in messages + ] + ) + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_llm_error(e) + raise e llm_output = self._combine_llm_outputs([res.llm_output for res in results]) generations = [res.generations for res in results] - return LLMResult(generations=generations, llm_output=llm_output) + output = LLMResult(generations=generations, llm_output=llm_output) + await run_manager.on_llm_end(output) + return output def generate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: prompt_messages = [p.to_messages() for p in prompts] - prompt_strings = 
[p.to_string() for p in prompts] - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose - ) - try: - output = self.generate(prompt_messages, stop=stop) - except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) - raise e - self.callback_manager.on_llm_end(output, verbose=self.verbose) - return output + return self.generate(prompt_messages, stop=stop, callbacks=callbacks) async def agenerate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: prompt_messages = [p.to_messages() for p in prompts] - prompt_strings = [p.to_string() for p in prompts] - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompt_strings, verbose=self.verbose - ) - try: - output = await self.agenerate(prompt_messages, stop=stop) - except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_error(e, verbose=self.verbose) - else: - self.callback_manager.on_llm_error(e, verbose=self.verbose) - raise e - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end(output, verbose=self.verbose) - else: - self.callback_manager.on_llm_end(output, verbose=self.verbose) - return output + return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks) @abstractmethod def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: """Top Level call""" @abstractmethod async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: """Top Level call""" def __call__( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> BaseMessage: - return self._generate(messages, stop=stop).generations[0].message + generation = self.generate( + [messages], stop=stop, callbacks=callbacks + ).generations[0][0] + if isinstance(generation, ChatGeneration): + return generation.message + else: + raise ValueError("Unexpected generation type") def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str: result = self([HumanMessage(content=message)], stop=stop) @@ -134,15 +187,21 @@ class BaseChatModel(BaseLanguageModel, ABC): class SimpleChatModel(BaseChatModel): def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: - output_str = self._call(messages, stop=stop) + output_str = self._call(messages, stop=stop, run_manager=run_manager) message = AIMessage(content=output_str) generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) @abstractmethod def _call( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + 
messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> str: """Simpler interface.""" diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index 4bbf1ee8..cd5efb1c 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -14,6 +14,10 @@ from tenacity import ( wait_exponential, ) +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chat_models.base import BaseChatModel from langchain.schema import ( AIMessage, @@ -242,7 +246,10 @@ class ChatOpenAI(BaseChatModel): return {"token_usage": overall_token_usage, "model_name": self.model_name} def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: message_dicts, params = self._create_message_dicts(messages, stop) if self.streaming: @@ -255,10 +262,8 @@ class ChatOpenAI(BaseChatModel): role = stream_resp["choices"][0]["delta"].get("role", role) token = stream_resp["choices"][0]["delta"].get("content", "") inner_completion += token - self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) + if run_manager: + run_manager.on_llm_new_token(token) message = _convert_dict_to_message( {"content": inner_completion, "role": role} ) @@ -287,7 +292,10 @@ class ChatOpenAI(BaseChatModel): return ChatResult(generations=generations, llm_output=llm_output) async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: message_dicts, params = self._create_message_dicts(messages, stop) if self.streaming: @@ -300,16 +308,8 @@ class ChatOpenAI(BaseChatModel): role = stream_resp["choices"][0]["delta"].get("role", role) token = stream_resp["choices"][0]["delta"].get("content", "") inner_completion += token - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) + if run_manager: + await run_manager.on_llm_new_token(token) message = _convert_dict_to_message( {"content": inner_completion, "role": role} ) diff --git a/langchain/chat_models/promptlayer_openai.py b/langchain/chat_models/promptlayer_openai.py index 38b66416..6f9b9a08 100644 --- a/langchain/chat_models/promptlayer_openai.py +++ b/langchain/chat_models/promptlayer_openai.py @@ -2,6 +2,10 @@ import datetime from typing import List, Optional +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chat_models import ChatOpenAI from langchain.schema import BaseMessage, ChatResult @@ -33,13 +37,16 @@ class PromptLayerChatOpenAI(ChatOpenAI): return_pl_id: Optional[bool] = False def _generate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> ChatResult: """Call ChatOpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request request_start_time = datetime.datetime.now().timestamp() - generated_responses = 
super()._generate(messages, stop) + generated_responses = super()._generate(messages, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() message_dicts, params = super()._create_message_dicts(messages, stop) for i, generation in enumerate(generated_responses.generations): @@ -67,13 +74,16 @@ class PromptLayerChatOpenAI(ChatOpenAI): return generated_responses async def _agenerate( - self, messages: List[BaseMessage], stop: Optional[List[str]] = None + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> ChatResult: """Call ChatOpenAI agenerate and then call PromptLayer to log.""" from promptlayer.utils import get_api_key, promptlayer_api_request_async request_start_time = datetime.datetime.now().timestamp() - generated_responses = await super()._agenerate(messages, stop) + generated_responses = await super()._agenerate(messages, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() message_dicts, params = super()._create_message_dicts(messages, stop) for i, generation in enumerate(generated_responses.generations): diff --git a/langchain/evaluation/agents/trajectory_eval_chain.py b/langchain/evaluation/agents/trajectory_eval_chain.py index f6f9cf08..d79171bb 100644 --- a/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/langchain/evaluation/agents/trajectory_eval_chain.py @@ -1,6 +1,7 @@ """A chain for evaluating ReAct style agents.""" from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.chat_models import ChatOpenAI @@ -94,7 +95,11 @@ Tool output: {output}""" return ["score", "reasoning"] return ["score"] - def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: raw_output = self.eval_chain.run( {"tool_descriptions": self._tools_description, **inputs} ) diff --git a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py index fffb413a..ba87e5ed 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py +++ b/langchain/experimental/autonomous_agents/baby_agi/baby_agi.py @@ -1,8 +1,11 @@ +"""BabyAGI agent.""" from collections import deque from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.experimental.autonomous_agents.baby_agi.task_creation import ( TaskCreationChain, @@ -13,7 +16,6 @@ from langchain.experimental.autonomous_agents.baby_agi.task_execution import ( from langchain.experimental.autonomous_agents.baby_agi.task_prioritization import ( TaskPrioritizationChain, ) -from langchain.schema import BaseLanguageModel from langchain.vectorstores.base import VectorStore @@ -112,7 +114,11 @@ class BabyAGI(Chain, BaseModel): objective=objective, context="\n".join(context), task=task ) - def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: """Run the agent.""" objective = inputs["objective"] first_task = 
inputs.get("first_task", "Make a todo list") diff --git a/langchain/experimental/autonomous_agents/baby_agi/task_creation.py b/langchain/experimental/autonomous_agents/baby_agi/task_creation.py index 122b0dbf..d3a1dc81 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/task_creation.py +++ b/langchain/experimental/autonomous_agents/baby_agi/task_creation.py @@ -1,5 +1,5 @@ from langchain import LLMChain, PromptTemplate -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel class TaskCreationChain(LLMChain): diff --git a/langchain/experimental/autonomous_agents/baby_agi/task_execution.py b/langchain/experimental/autonomous_agents/baby_agi/task_execution.py index b85619f8..aac943c0 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/task_execution.py +++ b/langchain/experimental/autonomous_agents/baby_agi/task_execution.py @@ -1,5 +1,5 @@ from langchain import LLMChain, PromptTemplate -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel class TaskExecutionChain(LLMChain): diff --git a/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py b/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py index 19e9d79a..d8b44c58 100644 --- a/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py +++ b/langchain/experimental/autonomous_agents/baby_agi/task_prioritization.py @@ -1,5 +1,5 @@ from langchain import LLMChain, PromptTemplate -from langchain.schema import BaseLanguageModel +from langchain.base_language import BaseLanguageModel class TaskPrioritizationChain(LLMChain): diff --git a/langchain/experimental/generative_agents/generative_agent.py b/langchain/experimental/generative_agents/generative_agent.py index ac5d951a..64780da8 100644 --- a/langchain/experimental/generative_agents/generative_agent.py +++ b/langchain/experimental/generative_agents/generative_agent.py @@ -5,9 +5,9 @@ from typing import Any, Dict, List, Optional, Tuple from pydantic import BaseModel, Field from langchain import LLMChain +from langchain.base_language import BaseLanguageModel from langchain.experimental.generative_agents.memory import GenerativeAgentMemory from langchain.prompts import PromptTemplate -from langchain.schema import BaseLanguageModel class GenerativeAgent(BaseModel): diff --git a/langchain/experimental/generative_agents/memory.py b/langchain/experimental/generative_agents/memory.py index 8719d1bf..5f1d65f4 100644 --- a/langchain/experimental/generative_agents/memory.py +++ b/langchain/experimental/generative_agents/memory.py @@ -3,9 +3,10 @@ import re from typing import Any, Dict, List, Optional from langchain import LLMChain +from langchain.base_language import BaseLanguageModel from langchain.prompts import PromptTemplate from langchain.retrievers import TimeWeightedVectorStoreRetriever -from langchain.schema import BaseLanguageModel, BaseMemory, Document +from langchain.schema import BaseMemory, Document logger = logging.getLogger(__name__) diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 4ec0326a..181adb0b 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional import requests from pydantic import BaseModel, Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -106,7 +107,12 @@ class AI21(LLM): """Return type of llm.""" 
return "ai21" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to AI21's complete endpoint. Args: diff --git a/langchain/llms/aleph_alpha.py b/langchain/llms/aleph_alpha.py index dd17bc44..bcdbebf8 100644 --- a/langchain/llms/aleph_alpha.py +++ b/langchain/llms/aleph_alpha.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -200,7 +201,12 @@ class AlephAlpha(LLM): """Return type of llm.""" return "alpeh_alpha" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Aleph Alpha's completion endpoint. Args: diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py index 04dc5850..b71fe682 100644 --- a/langchain/llms/anthropic.py +++ b/langchain/llms/anthropic.py @@ -5,6 +5,10 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tupl from pydantic import BaseModel, Extra, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -158,7 +162,12 @@ class Anthropic(LLM, _AnthropicCommon): # As a last resort, wrap the prompt ourselves to emulate instruct-style. return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: r"""Call out to Anthropic's completion endpoint. 
Args: @@ -187,9 +196,8 @@ class Anthropic(LLM, _AnthropicCommon): for data in stream_resp: delta = data["completion"][len(current_completion) :] current_completion = data["completion"] - self.callback_manager.on_llm_new_token( - delta, verbose=self.verbose, **data - ) + if run_manager: + run_manager.on_llm_new_token(delta, **data) return current_completion response = self.client.completion( prompt=self._wrap_prompt(prompt), @@ -198,7 +206,12 @@ class Anthropic(LLM, _AnthropicCommon): ) return response["completion"] - async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + ) -> str: """Call out to Anthropic's completion endpoint asynchronously.""" stop = self._get_anthropic_stop(stop) if self.streaming: @@ -211,14 +224,8 @@ class Anthropic(LLM, _AnthropicCommon): async for data in stream_resp: delta = data["completion"][len(current_completion) :] current_completion = data["completion"] - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - delta, verbose=self.verbose, **data - ) - else: - self.callback_manager.on_llm_new_token( - delta, verbose=self.verbose, **data - ) + if run_manager: + await run_manager.on_llm_new_token(delta, **data) return current_completion response = await self.client.acompletion( prompt=self._wrap_prompt(prompt), diff --git a/langchain/llms/bananadev.py b/langchain/llms/bananadev.py index 697ebcc7..8d95c1ed 100644 --- a/langchain/llms/bananadev.py +++ b/langchain/llms/bananadev.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -80,7 +81,12 @@ class Banana(LLM): """Return type of llm.""" return "banana" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to Banana endpoint.""" try: import banana_dev as banana diff --git a/langchain/llms/base.py b/langchain/llms/base.py index dd239792..dcab983d 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -1,16 +1,25 @@ """Base interface for large language models to expose.""" +import inspect import json +import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import yaml -from pydantic import Extra, Field, validator +from pydantic import Extra, Field, root_validator, validator import langchain -from langchain.callbacks import get_callback_manager +from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager -from langchain.schema import BaseLanguageModel, Generation, LLMResult, PromptValue +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForLLMRun, + CallbackManager, + CallbackManagerForLLMRun, + Callbacks, +) +from langchain.schema import Generation, LLMResult, PromptValue def _get_verbosity() -> bool: @@ -59,7 +68,8 @@ class BaseLLM(BaseLanguageModel, ABC): cache: Optional[bool] = None verbose: bool = Field(default_factory=_get_verbosity) """Whether to print out response text.""" - 
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) + callbacks: Callbacks = None + callback_manager: Optional[BaseCallbackManager] = None class Config: """Configuration for this pydantic object.""" @@ -67,15 +77,16 @@ class BaseLLM(BaseLanguageModel, ABC): extra = Extra.forbid arbitrary_types_allowed = True - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. - - This allows users to pass in None as callback manager, which is a nice UX. - """ - return callback_manager or get_callback_manager() + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values @validator("verbose", pre=True, always=True) def set_verbose(cls, verbose: Optional[bool]) -> bool: @@ -90,30 +101,45 @@ class BaseLLM(BaseLanguageModel, ABC): @abstractmethod def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompts.""" @abstractmethod async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompts.""" def generate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] - return self.generate(prompt_strings, stop=stop) + return self.generate(prompt_strings, stop=stop, callbacks=callbacks) async def agenerate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None + self, + prompts: List[PromptValue], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] - return await self.agenerate(prompt_strings, stop=stop) + return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks) def generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" # If string is passed in directly no errors will be raised but outputs will @@ -124,21 +150,31 @@ class BaseLLM(BaseLanguageModel, ABC): f" argument of type {type(prompts)}." ) disregard_cache = self.cache is not None and not self.cache + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._generate).parameters.get( + "run_manager" + ) if langchain.llm_cache is None or disregard_cache: # This happens when langchain.cache is None, but self.cache is True if self.cache is not None and self.cache: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, verbose=self.verbose + run_manager = callback_manager.on_llm_start( + {"name": self.__class__.__name__}, prompts ) try: - output = self._generate(prompts, stop=stop) + output = ( + self._generate(prompts, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._generate(prompts, stop=stop) + ) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + run_manager.on_llm_error(e) raise e - self.callback_manager.on_llm_end(output, verbose=self.verbose) + run_manager.on_llm_end(output) return output params = self.dict() params["stop"] = stop @@ -149,15 +185,19 @@ class BaseLLM(BaseLanguageModel, ABC): missing_prompts, ) = get_prompts(params, prompts) if len(missing_prompts) > 0: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose + run_manager = callback_manager.on_llm_start( + {"name": self.__class__.__name__}, missing_prompts ) try: - new_results = self._generate(missing_prompts, stop=stop) + new_results = ( + self._generate(missing_prompts, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._generate(missing_prompts, stop=stop) + ) except (KeyboardInterrupt, Exception) as e: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + run_manager.on_llm_error(e) raise e - self.callback_manager.on_llm_end(new_results, verbose=self.verbose) + run_manager.on_llm_end(new_results) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) @@ -167,36 +207,38 @@ class BaseLLM(BaseLanguageModel, ABC): return LLMResult(generations=generations, llm_output=llm_output) async def agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" disregard_cache = self.cache is not None and not self.cache + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, self.verbose + ) + new_arg_supported = inspect.signature(self._agenerate).parameters.get( + "run_manager" + ) if langchain.llm_cache is None or disregard_cache: # This happens when langchain.cache is None, but self.cache is True if self.cache is not None and self.cache: raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, verbose=self.verbose - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, verbose=self.verbose - ) + run_manager = await callback_manager.on_llm_start( + {"name": self.__class__.__name__}, prompts + ) try: - output = await self._agenerate(prompts, stop=stop) + output = ( + await self._agenerate(prompts, stop=stop, run_manager=run_manager) + if new_arg_supported + else await self._agenerate(prompts, stop=stop) + ) except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_error(e, verbose=self.verbose) - else: - self.callback_manager.on_llm_error(e, verbose=self.verbose) + await run_manager.on_llm_error(e, verbose=self.verbose) raise e - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end(output, verbose=self.verbose) - else: - self.callback_manager.on_llm_end(output, verbose=self.verbose) + await run_manager.on_llm_end(output, verbose=self.verbose) return output params = self.dict() params["stop"] = stop @@ -207,32 +249,22 @@ class BaseLLM(BaseLanguageModel, ABC): missing_prompts, ) = get_prompts(params, prompts) if len(missing_prompts) > 0: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, - missing_prompts, - verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_start( - {"name": self.__class__.__name__}, - missing_prompts, - verbose=self.verbose, - ) + run_manager = await callback_manager.on_llm_start( + {"name": self.__class__.__name__}, + missing_prompts, + ) try: - new_results = await self._agenerate(missing_prompts, stop=stop) - except (KeyboardInterrupt, Exception) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_llm_error(e, verbose=self.verbose) - else: - self.callback_manager.on_llm_error(e, verbose=self.verbose) - raise e - if self.callback_manager.is_async: - await self.callback_manager.on_llm_end( - new_results, verbose=self.verbose + new_results = ( + await self._agenerate( + missing_prompts, stop=stop, run_manager=run_manager + ) + if new_arg_supported + else await self._agenerate(missing_prompts, stop=stop) ) - else: - self.callback_manager.on_llm_end(new_results, verbose=self.verbose) + except (KeyboardInterrupt, Exception) as e: + await run_manager.on_llm_error(e) + raise e + await run_manager.on_llm_end(new_results) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) @@ -241,9 +273,15 @@ class BaseLLM(BaseLanguageModel, ABC): generations = [existing_prompts[i] for i in range(len(prompts))] return LLMResult(generations=generations, llm_output=llm_output) - def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def __call__( + self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None + ) -> str: """Check Cache and run the LLM on the given prompt and input.""" - return self.generate([prompt], stop=stop).generations[0][0].text + return ( + self.generate([prompt], stop=stop, callbacks=callbacks) + .generations[0][0] + .text + ) @property def _identifying_params(self) -> Mapping[str, Any]: @@ -307,30 +345,56 @@ class LLM(BaseLLM): """ @abstractmethod - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: 
Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Run the LLM on the given prompt and input.""" - async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: + async def _acall( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + ) -> str: """Run the LLM on the given prompt and input.""" raise NotImplementedError("Async generation not implemented for this LLM.") def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" # TODO: add caching here. generations = [] + new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") for prompt in prompts: - text = self._call(prompt, stop=stop) + text = ( + self._call(prompt, stop=stop, run_manager=run_manager) + if new_arg_supported + else self._call(prompt, stop=stop) + ) generations.append([Generation(text=text)]) return LLMResult(generations=generations) async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: """Run the LLM on the given prompt and input.""" generations = [] + new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") for prompt in prompts: - text = await self._acall(prompt, stop=stop) + text = ( + await self._acall(prompt, stop=stop, run_manager=run_manager) + if new_arg_supported + else await self._acall(prompt, stop=stop) + ) generations.append([Generation(text=text)]) return LLMResult(generations=generations) diff --git a/langchain/llms/cerebriumai.py b/langchain/llms/cerebriumai.py index 2937d7ff..3da3dfbc 100644 --- a/langchain/llms/cerebriumai.py +++ b/langchain/llms/cerebriumai.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -81,7 +82,12 @@ class CerebriumAI(LLM): """Return type of llm.""" return "cerebriumai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to CerebriumAI endpoint.""" try: from cerebrium import model_api_request diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 91894d1b..2eff193b 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -100,7 +101,12 @@ class Cohere(LLM): """Return type of llm.""" return "cohere" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Cohere's generate endpoint. 
Args: diff --git a/langchain/llms/deepinfra.py b/langchain/llms/deepinfra.py index 55b4c98b..f2c22823 100644 --- a/langchain/llms/deepinfra.py +++ b/langchain/llms/deepinfra.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -60,7 +61,12 @@ class DeepInfra(LLM): """Return type of llm.""" return "deepinfra" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to DeepInfra's inference API endpoint. Args: diff --git a/langchain/llms/fake.py b/langchain/llms/fake.py index aec4abb9..3df15b9c 100644 --- a/langchain/llms/fake.py +++ b/langchain/llms/fake.py @@ -1,6 +1,7 @@ """Fake LLM wrapper for testing purposes.""" from typing import Any, List, Mapping, Optional +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -15,7 +16,12 @@ class FakeListLLM(LLM): """Return type of llm.""" return "fake-list" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """First try to lookup in queries, else return 'foo' or 'bar'.""" response = self.responses[self.i] self.i += 1 diff --git a/langchain/llms/forefrontai.py b/langchain/llms/forefrontai.py index 1e34377a..8c49918a 100644 --- a/langchain/llms/forefrontai.py +++ b/langchain/llms/forefrontai.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -81,7 +82,12 @@ class ForefrontAI(LLM): """Return type of llm.""" return "forefrontai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to ForefrontAI's complete endpoint. 
Args: diff --git a/langchain/llms/gooseai.py b/langchain/llms/gooseai.py index ec7ca28d..571feb2b 100644 --- a/langchain/llms/gooseai.py +++ b/langchain/llms/gooseai.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -130,7 +131,12 @@ class GooseAI(LLM): """Return type of llm.""" return "gooseai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call the GooseAI API.""" params = self._default_params if stop is not None: diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index bf0300bb..ff3a6a5d 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -159,7 +160,12 @@ class GPT4All(LLM): """Return the type of llm.""" return "gpt4all" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: r"""Call out to GPT4All's generate method. Args: @@ -175,14 +181,15 @@ class GPT4All(LLM): prompt = "Once upon a time, " response = model(prompt, n_predict=55) """ - text_callback = partial( - self.callback_manager.on_llm_new_token, verbose=self.verbose - ) - text = self.client.generate( - prompt, - new_text_callback=text_callback, - **self._default_params, - ) + if run_manager: + text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose) + text = self.client.generate( + prompt, + new_text_callback=text_callback, + **self._default_params, + ) + else: + text = self.client.generate(prompt, **self._default_params) if stop is not None: text = enforce_stop_tokens(text, stop) return text diff --git a/langchain/llms/huggingface_endpoint.py b/langchain/llms/huggingface_endpoint.py index 7f7561c8..66a073c1 100644 --- a/langchain/llms/huggingface_endpoint.py +++ b/langchain/llms/huggingface_endpoint.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -88,7 +89,12 @@ class HuggingFaceEndpoint(LLM): """Return type of llm.""" return "huggingface_endpoint" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to HuggingFace Hub's inference endpoint. 
Args: diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index e7d3af99..2838b858 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -84,7 +85,12 @@ class HuggingFaceHub(LLM): """Return type of llm.""" return "huggingface_hub" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to HuggingFace Hub's inference endpoint. Args: diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py index be8787da..529cea28 100644 --- a/langchain/llms/huggingface_pipeline.py +++ b/langchain/llms/huggingface_pipeline.py @@ -5,6 +5,7 @@ from typing import Any, List, Mapping, Optional from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -146,7 +147,12 @@ class HuggingFacePipeline(LLM): def _llm_type(self) -> str: return "huggingface_pipeline" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: response = self.pipeline(prompt) if self.pipeline.task == "text-generation": # Text generation return includes the starter text. diff --git a/langchain/llms/llamacpp.py b/langchain/llms/llamacpp.py index b7416084..6a10af9d 100644 --- a/langchain/llms/llamacpp.py +++ b/langchain/llms/llamacpp.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Generator, List, Optional from pydantic import Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM logger = logging.getLogger(__name__) @@ -197,7 +198,12 @@ class LlamaCpp(LLM): return params - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call the Llama model and return the output. Args: @@ -219,7 +225,7 @@ class LlamaCpp(LLM): # method that yields as they are generated # and return the combined strings from the first choices's text: combined_text_output = "" - for token in self.stream(prompt=prompt, stop=stop): + for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager): combined_text_output += token["choices"][0]["text"] return combined_text_output else: @@ -228,7 +234,10 @@ class LlamaCpp(LLM): return result["choices"][0]["text"] def stream( - self, prompt: str, stop: Optional[List[str]] = None + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> Generator[Dict, None, None]: """Yields results objects as they are generated in real time. 
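
The hunk that follows routes each streamed chunk through the per-run manager (`run_manager.on_llm_new_token(...)`) instead of the model-level `callback_manager`. On the receiving side, any handler supplied via `callbacks` observes those tokens. A minimal sketch of such a handler is shown below; the `PrintTokensHandler` name is invented for illustration and is not part of this patch.

from typing import Any

from langchain.callbacks.base import BaseCallbackHandler


class PrintTokensHandler(BaseCallbackHandler):
    """Illustrative handler: print each token as a streaming LLM emits it."""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Invoked once per token whenever a streaming implementation forwards
        # output through run_manager.on_llm_new_token(...).
        print(token, end="", flush=True)

With this refactor the handler is attached per call (or per model) through `callbacks=[...]`, rather than by mutating a shared `callback_manager` on the model instance.
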
@@ -268,7 +277,8 @@ class LlamaCpp(LLM): for chunk in result: token = chunk["choices"][0]["text"] log_probs = chunk["choices"][0].get("logprobs", None) - self.callback_manager.on_llm_new_token( - token=token, verbose=self.verbose, log_probs=log_probs - ) + if run_manager: + run_manager.on_llm_new_token( + token=token, verbose=self.verbose, log_probs=log_probs + ) yield chunk diff --git a/langchain/llms/manifest.py b/langchain/llms/manifest.py index f6042144..0cef977e 100644 --- a/langchain/llms/manifest.py +++ b/langchain/llms/manifest.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -42,7 +43,12 @@ class ManifestWrapper(LLM): """Return type of llm.""" return "manifest" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to LLM through Manifest.""" if stop is not None and len(stop) != 1: raise NotImplementedError( diff --git a/langchain/llms/modal.py b/langchain/llms/modal.py index 4c159a39..53f112f7 100644 --- a/langchain/llms/modal.py +++ b/langchain/llms/modal.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -69,7 +70,12 @@ class Modal(LLM): """Return type of llm.""" return "modal" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to Modal endpoint.""" params = self.model_kwargs or {} response = requests.post( diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index b3b25d0b..72f6b38e 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -111,7 +112,12 @@ class NLPCloud(LLM): """Return type of llm.""" return "nlpcloud" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to NLPCloud's create endpoint. 
Args: diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index e8f85a76..7e374d26 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -29,6 +29,10 @@ from tenacity import ( wait_exponential, ) +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.llms.base import BaseLLM from langchain.schema import Generation, LLMResult from langchain.utils import get_from_dict_or_env @@ -254,7 +258,10 @@ class BaseOpenAI(BaseLLM): return {**normal_params, **self.model_kwargs} def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Call out to OpenAI's endpoint with k unique prompts. @@ -287,11 +294,12 @@ class BaseOpenAI(BaseLLM): for stream_resp in completion_with_retry( self, prompt=_prompts, **params ): - self.callback_manager.on_llm_new_token( - stream_resp["choices"][0]["text"], - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) + if run_manager: + run_manager.on_llm_new_token( + stream_resp["choices"][0]["text"], + verbose=self.verbose, + logprobs=stream_resp["choices"][0]["logprobs"], + ) _update_response(response, stream_resp) choices.extend(response["choices"]) else: @@ -303,7 +311,10 @@ class BaseOpenAI(BaseLLM): return self.create_llm_result(choices, prompts, token_usage) async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: """Call out to OpenAI's endpoint async with k unique prompts.""" params = self._invocation_params @@ -322,14 +333,8 @@ class BaseOpenAI(BaseLLM): async for stream_resp in await acompletion_with_retry( self, prompt=_prompts, **params ): - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( - stream_resp["choices"][0]["text"], - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) - else: - self.callback_manager.on_llm_new_token( + if run_manager: + await run_manager.on_llm_new_token( stream_resp["choices"][0]["text"], verbose=self.verbose, logprobs=stream_resp["choices"][0]["logprobs"], @@ -705,7 +710,10 @@ class OpenAIChat(BaseLLM): return messages, params def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: messages, params = self._get_chat_params(prompts, stop) if self.streaming: @@ -714,10 +722,10 @@ class OpenAIChat(BaseLLM): for stream_resp in completion_with_retry(self, messages=messages, **params): token = stream_resp["choices"][0]["delta"].get("content", "") response += token - self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, - ) + if run_manager: + run_manager.on_llm_new_token( + token, + ) return LLMResult( generations=[[Generation(text=response)]], ) @@ -735,7 +743,10 @@ class OpenAIChat(BaseLLM): ) async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: messages, params = self._get_chat_params(prompts, stop) if self.streaming: @@ -746,15 +757,9 @@ class OpenAIChat(BaseLLM): ): token = 
stream_resp["choices"][0]["delta"].get("content", "") response += token - if self.callback_manager.is_async: - await self.callback_manager.on_llm_new_token( + if run_manager: + await run_manager.on_llm_new_token( token, - verbose=self.verbose, - ) - else: - self.callback_manager.on_llm_new_token( - token, - verbose=self.verbose, ) return LLMResult( generations=[[Generation(text=response)]], diff --git a/langchain/llms/petals.py b/langchain/llms/petals.py index ed30a28f..293d240c 100644 --- a/langchain/llms/petals.py +++ b/langchain/llms/petals.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -130,7 +131,12 @@ class Petals(LLM): """Return type of llm.""" return "petals" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call the Petals API.""" params = self._default_params inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"] diff --git a/langchain/llms/pipelineai.py b/langchain/llms/pipelineai.py index 2a879622..3a29d64e 100644 --- a/langchain/llms/pipelineai.py +++ b/langchain/llms/pipelineai.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -80,7 +81,12 @@ class PipelineAI(LLM, BaseModel): """Return type of llm.""" return "pipeline_ai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to Pipeline Cloud endpoint.""" try: from pipeline import PipelineCloud diff --git a/langchain/llms/predictionguard.py b/langchain/llms/predictionguard.py index c5ba6165..4309cae5 100644 --- a/langchain/llms/predictionguard.py +++ b/langchain/llms/predictionguard.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -73,7 +74,12 @@ class PredictionGuard(LLM): """Return type of llm.""" return "predictionguard" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Prediction Guard's model proxy. Args: prompt: The prompt to pass into the model. 
diff --git a/langchain/llms/promptlayer_openai.py b/langchain/llms/promptlayer_openai.py index c7dd9cf3..77df8051 100644 --- a/langchain/llms/promptlayer_openai.py +++ b/langchain/llms/promptlayer_openai.py @@ -2,6 +2,10 @@ import datetime from typing import List, Optional +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.llms import OpenAI, OpenAIChat from langchain.schema import LLMResult @@ -33,13 +37,16 @@ class PromptLayerOpenAI(OpenAI): return_pl_id: Optional[bool] = False def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Call OpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request request_start_time = datetime.datetime.now().timestamp() - generated_responses = super()._generate(prompts, stop) + generated_responses = super()._generate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] @@ -69,12 +76,15 @@ class PromptLayerOpenAI(OpenAI): return generated_responses async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: from promptlayer.utils import get_api_key, promptlayer_api_request_async request_start_time = datetime.datetime.now().timestamp() - generated_responses = await super()._agenerate(prompts, stop) + generated_responses = await super()._agenerate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] @@ -131,13 +141,16 @@ class PromptLayerOpenAIChat(OpenAIChat): return_pl_id: Optional[bool] = False def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: """Call OpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request request_start_time = datetime.datetime.now().timestamp() - generated_responses = super()._generate(prompts, stop) + generated_responses = super()._generate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] @@ -167,12 +180,15 @@ class PromptLayerOpenAIChat(OpenAIChat): return generated_responses async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: from promptlayer.utils import get_api_key, promptlayer_api_request_async request_start_time = datetime.datetime.now().timestamp() - generated_responses = await super()._agenerate(prompts, stop) + generated_responses = await super()._agenerate(prompts, stop, run_manager) request_end_time = datetime.datetime.now().timestamp() for i in range(len(prompts)): prompt = prompts[i] diff --git a/langchain/llms/replicate.py b/langchain/llms/replicate.py index 6b487230..b1dfaf47 100644 --- a/langchain/llms/replicate.py +++ b/langchain/llms/replicate.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, 
Mapping, Optional from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.utils import get_from_dict_or_env @@ -78,7 +79,12 @@ class Replicate(LLM): """Return type of model.""" return "replicate" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call to replicate endpoint.""" try: import replicate as replicate_python diff --git a/langchain/llms/rwkv.py b/langchain/llms/rwkv.py index 5c27185a..0e873d48 100644 --- a/langchain/llms/rwkv.py +++ b/langchain/llms/rwkv.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, Mapping, Optional, Set from pydantic import BaseModel, Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -204,7 +205,12 @@ class RWKV(LLM, BaseModel): return decoded - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: r"""RWKV generation Args: diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py index 34f236b9..80cb663f 100644 --- a/langchain/llms/sagemaker_endpoint.py +++ b/langchain/llms/sagemaker_endpoint.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Generic, List, Mapping, Optional, TypeVar, Union from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -201,7 +202,12 @@ class SagemakerEndpoint(LLM): """Return type of llm.""" return "sagemaker_endpoint" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Sagemaker inference endpoint. 
Args: diff --git a/langchain/llms/self_hosted.py b/langchain/llms/self_hosted.py index df529d80..e7e51725 100644 --- a/langchain/llms/self_hosted.py +++ b/langchain/llms/self_hosted.py @@ -6,6 +6,7 @@ from typing import Any, Callable, List, Mapping, Optional from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens @@ -208,5 +209,10 @@ class SelfHostedPipeline(LLM): def _llm_type(self) -> str: return "self_hosted_llm" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) diff --git a/langchain/llms/self_hosted_hugging_face.py b/langchain/llms/self_hosted_hugging_face.py index dd62348c..49bd8536 100644 --- a/langchain/llms/self_hosted_hugging_face.py +++ b/langchain/llms/self_hosted_hugging_face.py @@ -5,6 +5,7 @@ from typing import Any, Callable, List, Mapping, Optional from pydantic import Extra +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.utils import enforce_stop_tokens @@ -198,5 +199,10 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline): def _llm_type(self) -> str: return "selfhosted_huggingface_pipeline" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) diff --git a/langchain/llms/stochasticai.py b/langchain/llms/stochasticai.py index 052e6efc..5d2fe730 100644 --- a/langchain/llms/stochasticai.py +++ b/langchain/llms/stochasticai.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import Extra, Field, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -80,7 +81,12 @@ class StochasticAI(LLM): """Return type of llm.""" return "stochasticai" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to StochasticAI's complete endpoint. Args: diff --git a/langchain/llms/writer.py b/langchain/llms/writer.py index a3a74f59..2cec1835 100644 --- a/langchain/llms/writer.py +++ b/langchain/llms/writer.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional import requests from pydantic import Extra, root_validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens from langchain.utils import get_from_dict_or_env @@ -117,7 +118,12 @@ class Writer(LLM): """Return type of llm.""" return "writer" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Call out to Writer's complete endpoint. 
Args: diff --git a/langchain/memory/entity.py b/langchain/memory/entity.py index e4a1aed0..70da11af 100644 --- a/langchain/memory/entity.py +++ b/langchain/memory/entity.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional from pydantic import Field +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import ( @@ -13,7 +14,7 @@ from langchain.memory.prompt import ( ) from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string +from langchain.schema import BaseMessage, get_buffer_string logger = logging.getLogger(__name__) diff --git a/langchain/memory/kg.py b/langchain/memory/kg.py index 8b2b5f6b..2c71a33c 100644 --- a/langchain/memory/kg.py +++ b/langchain/memory/kg.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Type, Union from pydantic import Field +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.graphs import NetworkxEntityGraph from langchain.graphs.networkx_graph import KnowledgeTriple, get_entities, parse_triples @@ -13,7 +14,6 @@ from langchain.memory.prompt import ( from langchain.memory.utils import get_prompt_input_key from langchain.prompts.base import BasePromptTemplate from langchain.schema import ( - BaseLanguageModel, BaseMessage, SystemMessage, get_buffer_string, diff --git a/langchain/memory/summary.py b/langchain/memory/summary.py index 4873b824..7a2d04f4 100644 --- a/langchain/memory/summary.py +++ b/langchain/memory/summary.py @@ -2,12 +2,12 @@ from typing import Any, Dict, List, Type from pydantic import BaseModel, root_validator +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts.base import BasePromptTemplate from langchain.schema import ( - BaseLanguageModel, BaseMessage, SystemMessage, get_buffer_string, diff --git a/langchain/memory/token_buffer.py b/langchain/memory/token_buffer.py index bb4da209..c5e3c01b 100644 --- a/langchain/memory/token_buffer.py +++ b/langchain/memory/token_buffer.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List +from langchain.base_language import BaseLanguageModel from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import BaseLanguageModel, BaseMessage, get_buffer_string +from langchain.schema import BaseMessage, get_buffer_string class ConversationTokenBufferMemory(BaseChatMemory): diff --git a/langchain/output_parsers/fix.py b/langchain/output_parsers/fix.py index dfa3d639..a46b2e4e 100644 --- a/langchain/output_parsers/fix.py +++ b/langchain/output_parsers/fix.py @@ -2,10 +2,11 @@ from __future__ import annotations from typing import TypeVar +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT from langchain.prompts.base import BasePromptTemplate -from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException +from langchain.schema import BaseOutputParser, OutputParserException T = TypeVar("T") diff --git a/langchain/output_parsers/retry.py b/langchain/output_parsers/retry.py index b1982608..080d1a49 100644 --- a/langchain/output_parsers/retry.py +++ 
b/langchain/output_parsers/retry.py @@ -2,11 +2,11 @@ from __future__ import annotations from typing import TypeVar +from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( - BaseLanguageModel, BaseOutputParser, OutputParserException, PromptValue, diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py index db4b5a67..71b4bc13 100644 --- a/langchain/retrievers/document_compressors/chain_extract.py +++ b/langchain/retrievers/document_compressors/chain_extract.py @@ -4,13 +4,12 @@ from __future__ import annotations from typing import Any, Callable, Dict, Optional, Sequence from langchain import LLMChain, PromptTemplate -from langchain.retrievers.document_compressors.base import ( - BaseDocumentCompressor, -) +from langchain.base_language import BaseLanguageModel +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers.document_compressors.chain_extract_prompt import ( prompt_template, ) -from langchain.schema import BaseLanguageModel, BaseOutputParser, Document +from langchain.schema import BaseOutputParser, Document def default_get_input(query: str, doc: Document) -> Dict[str, Any]: diff --git a/langchain/retrievers/document_compressors/chain_filter.py b/langchain/retrievers/document_compressors/chain_filter.py index f5e33e6b..245e0051 100644 --- a/langchain/retrievers/document_compressors/chain_filter.py +++ b/langchain/retrievers/document_compressors/chain_filter.py @@ -2,14 +2,13 @@ from typing import Any, Callable, Dict, Optional, Sequence from langchain import BasePromptTemplate, LLMChain, PromptTemplate +from langchain.base_language import BaseLanguageModel from langchain.output_parsers.boolean import BooleanOutputParser -from langchain.retrievers.document_compressors.base import ( - BaseDocumentCompressor, -) +from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers.document_compressors.chain_filter_prompt import ( prompt_template, ) -from langchain.schema import BaseLanguageModel, Document +from langchain.schema import Document def _get_default_chain_prompt() -> PromptTemplate: diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py index a9d7cad4..b74dfaca 100644 --- a/langchain/retrievers/self_query/base.py +++ b/langchain/retrievers/self_query/base.py @@ -4,13 +4,12 @@ from typing import Any, Dict, List, Optional, Type, cast from pydantic import BaseModel, Field, root_validator from langchain import LLMChain -from langchain.chains.query_constructor.base import ( - load_query_constructor_chain, -) +from langchain.base_language import BaseLanguageModel +from langchain.chains.query_constructor.base import load_query_constructor_chain from langchain.chains.query_constructor.ir import StructuredQuery, Visitor from langchain.chains.query_constructor.schema import AttributeInfo from langchain.retrievers.self_query.pinecone import PineconeTranslator -from langchain.schema import BaseLanguageModel, BaseRetriever, Document +from langchain.schema import BaseRetriever, Document from langchain.vectorstores import Pinecone, VectorStore @@ -69,7 +68,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): """ inputs = self.llm_chain.prep_inputs(query) structured_query = cast( - StructuredQuery, 
self.llm_chain.predict_and_parse(**inputs) + StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs) ) if self.verbose: print(structured_query) diff --git a/langchain/schema.py b/langchain/schema.py index 1e4770e0..ced8b383 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -181,49 +181,6 @@ class PromptValue(BaseModel, ABC): """Return prompt as messages.""" -def _get_num_tokens_default_method(text: str) -> int: - """Get the number of tokens present in the text.""" - # TODO: this method may not be exact. - # TODO: this method may differ based on model (eg codex). - try: - from transformers import GPT2TokenizerFast - except ImportError: - raise ValueError( - "Could not import transformers python package. " - "This is needed in order to calculate get_num_tokens. " - "Please install it with `pip install transformers`." - ) - # create a GPT-2 tokenizer instance - tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") - - # tokenize the text using the GPT-3 tokenizer - tokenized_text = tokenizer.tokenize(text) - - # calculate the number of tokens in the tokenized text - return len(tokenized_text) - - -class BaseLanguageModel(BaseModel, ABC): - @abstractmethod - def generate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None - ) -> LLMResult: - """Take in a list of prompt values and return an LLMResult.""" - - @abstractmethod - async def agenerate_prompt( - self, prompts: List[PromptValue], stop: Optional[List[str]] = None - ) -> LLMResult: - """Take in a list of prompt values and return an LLMResult.""" - - def get_num_tokens(self, text: str) -> int: - return _get_num_tokens_default_method(text) - - def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: - """Get the number of tokens in the message.""" - return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) - - class BaseMemory(BaseModel, ABC): """Base interface for memory in chains.""" diff --git a/langchain/tools/arxiv/tool.py b/langchain/tools/arxiv/tool.py index 83c21131..76513e27 100644 --- a/langchain/tools/arxiv/tool.py +++ b/langchain/tools/arxiv/tool.py @@ -1,5 +1,11 @@ """Tool for the Arxiv API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.arxiv import ArxivAPIWrapper @@ -18,10 +24,18 @@ class ArxivQueryRun(BaseTool): ) api_wrapper: ArxivAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Arxiv tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Arxiv tool asynchronously.""" raise NotImplementedError("ArxivAPIWrapper does not support async") diff --git a/langchain/tools/base.py b/langchain/tools/base.py index 9090c683..6799409e 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -1,6 +1,7 @@ """Base implementation for tools or skills.""" from __future__ import annotations +import warnings from abc import ABC, abstractmethod from inspect import signature from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union @@ -10,13 +11,19 @@ from pydantic import ( Extra, Field, create_model, + root_validator, validate_arguments, - validator, ) from pydantic.main import 
ModelMetaclass -from langchain.callbacks import get_callback_manager from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import ( + AsyncCallbackManager, + AsyncCallbackManagerForToolRun, + CallbackManager, + CallbackManagerForToolRun, + Callbacks, +) class SchemaAnnotationError(TypeError): @@ -79,7 +86,14 @@ def get_filtered_args( """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys} + return {k: schema[k] for k in valid_keys if k != "run_manager"} + + +class _SchemaConfig: + """Configuration for the pydantic model.""" + + extra = Extra.forbid + arbitrary_types_allowed = True def create_schema_from_function( @@ -87,7 +101,10 @@ def create_schema_from_function( func: Callable, ) -> Type[BaseModel]: """Create a pydantic schema from a function's signature.""" - inferred_model = validate_arguments(func).model # type: ignore + validated = validate_arguments(func, config=_SchemaConfig) # type: ignore + inferred_model = validated.model # type: ignore + if "run_manager" in inferred_model.__fields__: + del inferred_model.__fields__["run_manager"] # Pydantic adds placeholder virtual fields we need to strip filtered_args = get_filtered_args(inferred_model, func) return _create_subset_model( @@ -114,8 +131,11 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): """ verbose: bool = False """Whether to log the tool's progress.""" - callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager) - """Callback manager for this tool.""" + + callbacks: Callbacks = None + """Callbacks to be called during tool execution.""" + callback_manager: Optional[BaseCallbackManager] = None + """Deprecated. Please use callbacks instead.""" class Config: """Configuration for this pydantic object.""" @@ -133,8 +153,8 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): if self.args_schema is not None: return self.args_schema.schema()["properties"] else: - inferred_model = validate_arguments(self._run).model # type: ignore - return get_filtered_args(inferred_model, self._run) + schema = create_schema_from_function(self.name, self._run) + return schema.schema()["properties"] def _parse_input( self, @@ -150,23 +170,40 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): if input_args is not None: input_args.validate(tool_input) - @validator("callback_manager", pre=True, always=True) - def set_callback_manager( - cls, callback_manager: Optional[BaseCallbackManager] - ) -> BaseCallbackManager: - """If callback manager is None, set it. + @root_validator() + def raise_deprecation(cls, values: Dict) -> Dict: + """Raise deprecation warning if callback_manager is used.""" + if values.get("callback_manager") is not None: + warnings.warn( + "callback_manager is deprecated. Please use callbacks instead.", + DeprecationWarning, + ) + values["callbacks"] = values.pop("callback_manager", None) + return values - This allows users to pass in None as callback manager, which is a nice UX. + @abstractmethod + def _run( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool. 
+ + Add run_manager: Optional[CallbackManagerForToolRun] = None + to child implementations to enable tracing, """ - return callback_manager or get_callback_manager() @abstractmethod - def _run(self, *args: Any, **kwargs: Any) -> Any: - """Use the tool.""" + async def _arun( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Use the tool asynchronously. - @abstractmethod - async def _arun(self, *args: Any, **kwargs: Any) -> Any: - """Use the tool asynchronously.""" + Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None + to child implementations to enable tracing, + """ def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: # For backwards compatibility, if run_input is a string, @@ -182,30 +219,37 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", + callbacks: Callbacks = None, **kwargs: Any, - ) -> str: + ) -> Any: """Run the tool.""" self._parse_input(tool_input) if not self.verbose and verbose is not None: verbose_ = verbose else: verbose_ = self.verbose - self.callback_manager.on_tool_start( + callback_manager = CallbackManager.configure( + callbacks, self.callbacks, verbose=verbose_ + ) + # TODO: maybe also pass through run_manager is _run supports kwargs + new_arg_supported = signature(self._run).parameters.get("run_manager") + run_manager = callback_manager.on_tool_start( {"name": self.name, "description": self.description}, tool_input if isinstance(tool_input, str) else str(tool_input), - verbose=verbose_, color=start_color, **kwargs, ) try: tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) - observation = self._run(*tool_args, **tool_kwargs) + observation = ( + self._run(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else self._run(*tool_args, **tool_kwargs) + ) except (Exception, KeyboardInterrupt) as e: - self.callback_manager.on_tool_error(e, verbose=verbose_) + run_manager.on_tool_error(e) raise e - self.callback_manager.on_tool_end( - str(observation), verbose=verbose_, color=color, name=self.name, **kwargs - ) + run_manager.on_tool_end(str(observation), color=color, name=self.name, **kwargs) return observation async def arun( @@ -214,6 +258,7 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): verbose: Optional[bool] = None, start_color: Optional[str] = "green", color: Optional[str] = "green", + callbacks: Callbacks = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" @@ -222,49 +267,35 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): verbose_ = verbose else: verbose_ = self.verbose - if self.callback_manager.is_async: - await self.callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), - verbose=verbose_, - color=start_color, - **kwargs, - ) - else: - self.callback_manager.on_tool_start( - {"name": self.name, "description": self.description}, - tool_input if isinstance(tool_input, str) else str(tool_input), - verbose=verbose_, - color=start_color, - **kwargs, - ) + callback_manager = AsyncCallbackManager.configure( + callbacks, self.callbacks, verbose=verbose_ + ) + new_arg_supported = signature(self._arun).parameters.get("run_manager") + run_manager = await callback_manager.on_tool_start( + {"name": self.name, "description": self.description}, + tool_input if isinstance(tool_input, str) else str(tool_input), + color=start_color, + **kwargs, + ) try: 
# We then call the tool on the tool input to get an observation tool_args, tool_kwargs = self._to_args_and_kwargs(tool_input) - observation = await self._arun(*tool_args, **tool_kwargs) + observation = ( + await self._arun(*tool_args, run_manager=run_manager, **tool_kwargs) + if new_arg_supported + else await self._arun(*tool_args, **tool_kwargs) + ) except (Exception, KeyboardInterrupt) as e: - if self.callback_manager.is_async: - await self.callback_manager.on_tool_error(e, verbose=verbose_) - else: - self.callback_manager.on_tool_error(e, verbose=verbose_) + await run_manager.on_tool_error(e) raise e - if self.callback_manager.is_async: - await self.callback_manager.on_tool_end( - str(observation), - verbose=verbose_, - color=color, - name=self.name, - **kwargs, - ) - else: - self.callback_manager.on_tool_end( - observation, verbose=verbose_, color=color, name=self.name, **kwargs - ) + await run_manager.on_tool_end( + str(observation), color=color, name=self.name, **kwargs + ) return observation - def __call__(self, tool_input: Union[str, dict]) -> Any: + def __call__(self, tool_input: str, callbacks: Callbacks = None) -> str: """Make tool callable.""" - return self.run(tool_input) + return self.run(tool_input, callbacks=callbacks) class StructuredTool(BaseTool): @@ -273,9 +304,9 @@ class StructuredTool(BaseTool): description: str = "" args_schema: Type[BaseModel] = Field(..., description="The tool schema.") """The input arguments' schema.""" - func: Callable[..., str] + func: Callable[..., Any] """The function to run when the tool is called.""" - coroutine: Optional[Callable[..., Awaitable[str]]] = None + coroutine: Optional[Callable[..., Awaitable[Any]]] = None """The asynchronous version of the function.""" @property @@ -283,14 +314,44 @@ class StructuredTool(BaseTool): """The tool's input arguments.""" return self.args_schema.schema()["properties"] - def _run(self, *args: Any, **kwargs: Any) -> Any: + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool.""" - return self.func(*args, **kwargs) + new_argument_supported = signature(self.func).parameters.get("callbacks") + return ( + self.func( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else self.func(*args, **kwargs) + ) - async def _arun(self, *args: Any, **kwargs: Any) -> Any: + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Use the tool asynchronously.""" if self.coroutine: - return await self.coroutine(*args, **kwargs) + new_argument_supported = signature(self.coroutine).parameters.get( + "callbacks" + ) + return ( + await self.coroutine( + *args, + callbacks=run_manager.get_child() if run_manager else None, + **kwargs, + ) + if new_argument_supported + else await self.coroutine(*args, **kwargs) + ) raise NotImplementedError("Tool does not support async") @classmethod diff --git a/langchain/tools/bing_search/tool.py b/langchain/tools/bing_search/tool.py index dd57295c..3340a55a 100644 --- a/langchain/tools/bing_search/tool.py +++ b/langchain/tools/bing_search/tool.py @@ -1,5 +1,11 @@ """Tool for the Bing search API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.bing_search import BingSearchAPIWrapper @@ -15,11 +21,19 @@ class 
BingSearchRun(BaseTool): ) api_wrapper: BingSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("BingSearchRun does not support async") @@ -36,10 +50,18 @@ class BingSearchResults(BaseTool): num_results: int = 4 api_wrapper: BingSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.api_wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("BingSearchResults does not support async") diff --git a/langchain/tools/ddg_search/tool.py b/langchain/tools/ddg_search/tool.py index 5948756f..109431e2 100644 --- a/langchain/tools/ddg_search/tool.py +++ b/langchain/tools/ddg_search/tool.py @@ -1,10 +1,14 @@ """Tool for the DuckDuckGo search API.""" import warnings -from typing import Any +from typing import Any, Optional from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper @@ -22,11 +26,19 @@ class DuckDuckGoSearchRun(BaseTool): default_factory=DuckDuckGoSearchAPIWrapper ) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("DuckDuckGoSearch does not support async") @@ -45,11 +57,19 @@ class DuckDuckGoSearchResults(BaseTool): default_factory=DuckDuckGoSearchAPIWrapper ) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.api_wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("DuckDuckGoSearchResults does not support async") diff --git a/langchain/tools/file_management/copy.py b/langchain/tools/file_management/copy.py index 0fb23c7e..5231c7d4 100644 --- a/langchain/tools/file_management/copy.py +++ b/langchain/tools/file_management/copy.py @@ -1,11 +1,16 @@ import shutil -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -17,12 +22,17 @@ class 
FileCopyInput(BaseModel): destination_path: str = Field(..., description="Path to save the copied file") -class CopyFileTool(BaseFileTool): +class CopyFileTool(BaseFileToolMixin, BaseTool): name: str = "copy_file" args_schema: Type[BaseModel] = FileCopyInput description: str = "Create a copy of a file in a specified location" - def _run(self, source_path: str, destination_path: str) -> str: + def _run( + self, + source_path: str, + destination_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: source_path_ = self.get_relative_path(source_path) except FileValidationError: @@ -41,6 +51,11 @@ class CopyFileTool(BaseFileTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, source_path: str, destination_path: str) -> str: + async def _arun( + self, + source_path: str, + destination_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/delete.py b/langchain/tools/file_management/delete.py index 218cf606..bf00e707 100644 --- a/langchain/tools/file_management/delete.py +++ b/langchain/tools/file_management/delete.py @@ -1,11 +1,16 @@ import os -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -16,12 +21,16 @@ class FileDeleteInput(BaseModel): file_path: str = Field(..., description="Path of the file to delete") -class DeleteFileTool(BaseFileTool): +class DeleteFileTool(BaseFileToolMixin, BaseTool): name: str = "file_delete" args_schema: Type[BaseModel] = FileDeleteInput description: str = "Delete a file" - def _run(self, file_path: str) -> str: + def _run( + self, + file_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: file_path_ = self.get_relative_path(file_path) except FileValidationError: @@ -34,6 +43,10 @@ class DeleteFileTool(BaseFileTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, file_path: str) -> str: + async def _arun( + self, + file_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/file_search.py b/langchain/tools/file_management/file_search.py index 7e2f1d93..ce67f59d 100644 --- a/langchain/tools/file_management/file_search.py +++ b/langchain/tools/file_management/file_search.py @@ -1,12 +1,17 @@ import fnmatch import os -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -24,14 +29,19 @@ class FileSearchInput(BaseModel): ) -class FileSearchTool(BaseFileTool): +class FileSearchTool(BaseFileToolMixin, BaseTool): name: str = "file_search" args_schema: Type[BaseModel] = FileSearchInput description: str = ( "Recursively search for files in a subdirectory that match the regex pattern" ) - def _run(self, pattern: str, dir_path: str = ".") 
-> str: + def _run( + self, + pattern: str, + dir_path: str = ".", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: dir_path_ = self.get_relative_path(dir_path) except FileValidationError: @@ -50,6 +60,11 @@ class FileSearchTool(BaseFileTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, dir_path: str, pattern: str) -> str: + async def _arun( + self, + dir_path: str, + pattern: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/list_dir.py b/langchain/tools/file_management/list_dir.py index ff5cb8a1..f013257d 100644 --- a/langchain/tools/file_management/list_dir.py +++ b/langchain/tools/file_management/list_dir.py @@ -1,11 +1,16 @@ import os -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -16,12 +21,16 @@ class DirectoryListingInput(BaseModel): dir_path: str = Field(default=".", description="Subdirectory to list.") -class ListDirectoryTool(BaseFileTool): +class ListDirectoryTool(BaseFileToolMixin, BaseTool): name: str = "list_directory" args_schema: Type[BaseModel] = DirectoryListingInput description: str = "List files and directories in a specified folder" - def _run(self, dir_path: str = ".") -> str: + def _run( + self, + dir_path: str = ".", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: dir_path_ = self.get_relative_path(dir_path) except FileValidationError: @@ -35,6 +44,10 @@ class ListDirectoryTool(BaseFileTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, dir_path: str) -> str: + async def _arun( + self, + dir_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/move.py b/langchain/tools/file_management/move.py index ccf88796..b4cc1a94 100644 --- a/langchain/tools/file_management/move.py +++ b/langchain/tools/file_management/move.py @@ -1,11 +1,16 @@ import shutil -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -17,12 +22,17 @@ class FileMoveInput(BaseModel): destination_path: str = Field(..., description="New path for the moved file") -class MoveFileTool(BaseFileTool): +class MoveFileTool(BaseFileToolMixin, BaseTool): name: str = "move_file" args_schema: Type[BaseModel] = FileMoveInput description: str = "Move or rename a file from one location to another" - def _run(self, source_path: str, destination_path: str) -> str: + def _run( + self, + source_path: str, + destination_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: source_path_ = self.get_relative_path(source_path) except FileValidationError: @@ -44,6 +54,11 @@ class MoveFileTool(BaseFileTool): except Exception as e: return "Error: " + str(e) - async 
def _arun(self, source_path: str, destination_path: str) -> str: + async def _arun( + self, + source_path: str, + destination_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/read.py b/langchain/tools/file_management/read.py index d243a9e3..86d6191d 100644 --- a/langchain/tools/file_management/read.py +++ b/langchain/tools/file_management/read.py @@ -1,10 +1,15 @@ -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -15,12 +20,16 @@ class ReadFileInput(BaseModel): file_path: str = Field(..., description="name of file") -class ReadFileTool(BaseFileTool): +class ReadFileTool(BaseFileToolMixin, BaseTool): name: str = "read_file" args_schema: Type[BaseModel] = ReadFileInput description: str = "Read file from disk" - def _run(self, file_path: str) -> str: + def _run( + self, + file_path: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: read_path = self.get_relative_path(file_path) except FileValidationError: @@ -34,6 +43,10 @@ class ReadFileTool(BaseFileTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, file_path: str) -> str: + async def _arun( + self, + file_path: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/file_management/utils.py b/langchain/tools/file_management/utils.py index c8efefb4..788823fe 100644 --- a/langchain/tools/file_management/utils.py +++ b/langchain/tools/file_management/utils.py @@ -1,11 +1,9 @@ import sys from pathlib import Path -from typing import Any, Optional +from typing import Optional from pydantic import BaseModel -from langchain.tools.base import BaseTool - def is_relative_to(path: Path, root: Path) -> bool: """Check if path is relative to root.""" @@ -29,8 +27,8 @@ class FileValidationError(ValueError): """Error for paths outside the root directory.""" -class BaseFileTool(BaseTool, BaseModel): - """Input for ReadFileTool.""" +class BaseFileToolMixin(BaseModel): + """Mixin for file system tools.""" root_dir: Optional[str] = None """The final path will be chosen relative to root_dir if specified.""" @@ -41,12 +39,6 @@ class BaseFileTool(BaseTool, BaseModel): return Path(file_path) return get_validated_relative_path(Path(self.root_dir), file_path) - def _run(self, *args: Any, **kwargs: Any) -> str: - raise NotImplementedError - - async def _arun(self, *args: Any, **kwargs: Any) -> str: - raise NotImplementedError - def get_validated_relative_path(root: Path, user_path: str) -> Path: """Resolve a relative path, raising an error if not within the root directory.""" diff --git a/langchain/tools/file_management/write.py b/langchain/tools/file_management/write.py index 865bcbe7..fcebe1c7 100644 --- a/langchain/tools/file_management/write.py +++ b/langchain/tools/file_management/write.py @@ -1,10 +1,15 @@ -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from 
langchain.tools.base import BaseTool from langchain.tools.file_management.utils import ( INVALID_PATH_TEMPLATE, - BaseFileTool, + BaseFileToolMixin, FileValidationError, ) @@ -19,12 +24,18 @@ class WriteFileInput(BaseModel): ) -class WriteFileTool(BaseFileTool): +class WriteFileTool(BaseFileToolMixin, BaseTool): name: str = "write_file" args_schema: Type[BaseModel] = WriteFileInput description: str = "Write file to disk" - def _run(self, file_path: str, text: str, append: bool = False) -> str: + def _run( + self, + file_path: str, + text: str, + append: bool = False, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: try: write_path = self.get_relative_path(file_path) except FileValidationError: @@ -38,6 +49,12 @@ class WriteFileTool(BaseFileTool): except Exception as e: return "Error: " + str(e) - async def _arun(self, file_path: str, text: str, append: bool = False) -> str: + async def _arun( + self, + file_path: str, + text: str, + append: bool = False, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: # TODO: Add aiofiles method raise NotImplementedError diff --git a/langchain/tools/google_places/tool.py b/langchain/tools/google_places/tool.py index 31ae39da..27b0b561 100644 --- a/langchain/tools/google_places/tool.py +++ b/langchain/tools/google_places/tool.py @@ -1,11 +1,21 @@ """Tool for the Google search API.""" -from pydantic import Field +from typing import Optional +from pydantic import BaseModel, Field + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.google_places_api import GooglePlacesAPIWrapper +class GooglePlacesSchema(BaseModel): + query: str = Field(..., description="Query for goole maps") + + class GooglePlacesTool(BaseTool): """Tool that adds the capability to query the Google places API.""" @@ -18,10 +28,18 @@ class GooglePlacesTool(BaseTool): ) api_wrapper: GooglePlacesAPIWrapper = Field(default_factory=GooglePlacesAPIWrapper) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("GooglePlacesRun does not support async") diff --git a/langchain/tools/google_search/tool.py b/langchain/tools/google_search/tool.py index 1945a3df..71288e19 100644 --- a/langchain/tools/google_search/tool.py +++ b/langchain/tools/google_search/tool.py @@ -1,5 +1,11 @@ """Tool for the Google search API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.google_search import GoogleSearchAPIWrapper @@ -15,11 +21,19 @@ class GoogleSearchRun(BaseTool): ) api_wrapper: GoogleSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise 
NotImplementedError("GoogleSearchRun does not support async") @@ -36,10 +50,18 @@ class GoogleSearchResults(BaseTool): num_results: int = 4 api_wrapper: GoogleSearchAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.api_wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("GoogleSearchRun does not support async") diff --git a/langchain/tools/human/tool.py b/langchain/tools/human/tool.py index de2cce81..a207c6b1 100644 --- a/langchain/tools/human/tool.py +++ b/langchain/tools/human/tool.py @@ -1,9 +1,13 @@ """Tool for asking human input.""" -from typing import Callable +from typing import Callable, Optional from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -24,11 +28,19 @@ class HumanInputRun(BaseTool): prompt_func: Callable[[str], None] = Field(default_factory=lambda: _print_func) input_func: Callable = Field(default_factory=lambda: input) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Human input tool.""" self.prompt_func(query) return self.input_func() - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Human tool asynchronously.""" raise NotImplementedError("Human tool does not support async") diff --git a/langchain/tools/ifttt.py b/langchain/tools/ifttt.py index 8d3d943a..e42c232f 100644 --- a/langchain/tools/ifttt.py +++ b/langchain/tools/ifttt.py @@ -32,8 +32,14 @@ service, and you're ready to start receiving data and triggering actions 🎉 - Copy the IFTTT key value from there. The URL is of the form https://maker.ifttt.com/use/YOUR_IFTTT_KEY. Grab the YOUR_IFTTT_KEY value. """ +from typing import Optional + import requests +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -48,10 +54,18 @@ class IFTTTWebhook(BaseTool): url: str - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: body = {"this": tool_input} response = requests.post(self.url, data=body) return response.text - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("Not implemented.") diff --git a/langchain/tools/interaction/tool.py b/langchain/tools/interaction/tool.py index ee2b51ca..096c885d 100644 --- a/langchain/tools/interaction/tool.py +++ b/langchain/tools/interaction/tool.py @@ -1,6 +1,12 @@ """Tools for interacting with the user.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + RunManager, +) from langchain.tools.base import BaseTool @@ -14,9 +20,13 @@ class StdInInquireTool(BaseTool): " question (to disambiguate) or a request for more context." 
) - def _run(self, prompt: str) -> str: + def _run(self, prompt: str, run_manager: Optional[RunManager] = None) -> str: """Prompt the user for more input.""" return input(f"\n{prompt}") - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support async") diff --git a/langchain/tools/jira/tool.py b/langchain/tools/jira/tool.py index 86861759..6c75ca91 100644 --- a/langchain/tools/jira/tool.py +++ b/langchain/tools/jira/tool.py @@ -28,8 +28,14 @@ agent = initialize_agent( ) ``` """ +from typing import Optional + from pydantic import Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.jira import JiraAPIWrapper @@ -40,10 +46,18 @@ class JiraAction(BaseTool): name = "" description = "" - def _run(self, instructions: str) -> str: + def _run( + self, + instructions: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Atlassian Jira API to run an operation.""" return self.api_wrapper.run(self.mode, instructions) - async def _arun(self, _: str) -> str: + async def _arun( + self, + _: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Atlassian Jira API to run an operation.""" raise NotImplementedError("JiraAction does not support async") diff --git a/langchain/tools/json/tool.py b/langchain/tools/json/tool.py index 9f1bdac7..6f6473d5 100644 --- a/langchain/tools/json/tool.py +++ b/langchain/tools/json/tool.py @@ -5,10 +5,14 @@ from __future__ import annotations import json import re from pathlib import Path -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union from pydantic import BaseModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -88,10 +92,18 @@ class JsonListKeysTool(BaseTool): """ spec: JsonSpec - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: return self.spec.keys(tool_input) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return self._run(tool_input) @@ -106,8 +118,16 @@ class JsonGetValueTool(BaseTool): """ spec: JsonSpec - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: return self.spec.value(tool_input) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return self._run(tool_input) diff --git a/langchain/tools/playwright/click.py b/langchain/tools/playwright/click.py index 601913ab..671faf43 100644 --- a/langchain/tools/playwright/click.py +++ b/langchain/tools/playwright/click.py @@ -1,9 +1,13 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import 
( aget_current_page, @@ -22,7 +26,11 @@ class ClickTool(BaseBrowserTool): description: str = "Click on an element with the given CSS selector" args_schema: Type[BaseModel] = ClickToolInput - def _run(self, selector: str) -> str: + def _run( + self, + selector: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.sync_browser is None: raise ValueError(f"Synchronous browser not provided to {self.name}") @@ -31,7 +39,11 @@ class ClickTool(BaseBrowserTool): page.click(selector) return f"Clicked element '{selector}'" - async def _arun(self, selector: str) -> str: + async def _arun( + self, + selector: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.async_browser is None: raise ValueError(f"Asynchronous browser not provided to {self.name}") diff --git a/langchain/tools/playwright/current_page.py b/langchain/tools/playwright/current_page.py index 77b686cc..b0e51c25 100644 --- a/langchain/tools/playwright/current_page.py +++ b/langchain/tools/playwright/current_page.py @@ -1,14 +1,15 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel -from langchain.tools.playwright.base import BaseBrowserTool -from langchain.tools.playwright.utils import ( - aget_current_page, - get_current_page, +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, ) +from langchain.tools.playwright.base import BaseBrowserTool +from langchain.tools.playwright.utils import aget_current_page, get_current_page class CurrentWebPageTool(BaseBrowserTool): @@ -16,14 +17,20 @@ class CurrentWebPageTool(BaseBrowserTool): description: str = "Returns the URL of the current page" args_schema: Type[BaseModel] = BaseModel - def _run(self) -> str: + def _run( + self, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.sync_browser is None: raise ValueError(f"Synchronous browser not provided to {self.name}") page = get_current_page(self.sync_browser) return str(page.url) - async def _arun(self) -> str: + async def _arun( + self, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.async_browser is None: raise ValueError(f"Asynchronous browser not provided to {self.name}") diff --git a/langchain/tools/playwright/extract_hyperlinks.py b/langchain/tools/playwright/extract_hyperlinks.py index 4f088b31..03902fa5 100644 --- a/langchain/tools/playwright/extract_hyperlinks.py +++ b/langchain/tools/playwright/extract_hyperlinks.py @@ -1,10 +1,14 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any, Optional, Type from pydantic import BaseModel, Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import aget_current_page, get_current_page @@ -59,7 +63,11 @@ class ExtractHyperlinksTool(BaseBrowserTool): # Return the list of links as a JSON string return json.dumps(links) - def _run(self, absolute_urls: bool = False) -> str: + def _run( + self, + absolute_urls: bool = False, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.sync_browser is None: raise ValueError(f"Synchronous browser not provided to {self.name}") @@ -67,7 
+75,11 @@ class ExtractHyperlinksTool(BaseBrowserTool): html_content = page.content() return self.scrape_page(page, html_content, absolute_urls) - async def _arun(self, absolute_urls: bool = False) -> str: + async def _arun( + self, + absolute_urls: bool = False, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" if self.async_browser is None: raise ValueError(f"Asynchronous browser not provided to {self.name}") diff --git a/langchain/tools/playwright/extract_text.py b/langchain/tools/playwright/extract_text.py index 9c7dacf9..5b228786 100644 --- a/langchain/tools/playwright/extract_text.py +++ b/langchain/tools/playwright/extract_text.py @@ -1,9 +1,13 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import aget_current_page, get_current_page @@ -25,7 +29,7 @@ class ExtractTextTool(BaseBrowserTool): ) return values - def _run(self) -> str: + def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None) -> str: """Use the tool.""" # Use Beautiful Soup since it's faster than looping through the elements from bs4 import BeautifulSoup @@ -41,7 +45,9 @@ class ExtractTextTool(BaseBrowserTool): return " ".join(text for text in soup.stripped_strings) - async def _arun(self) -> str: + async def _arun( + self, run_manager: Optional[AsyncCallbackManagerForToolRun] = None + ) -> str: """Use the tool.""" if self.async_browser is None: raise ValueError(f"Asynchronous browser not provided to {self.name}") diff --git a/langchain/tools/playwright/get_elements.py b/langchain/tools/playwright/get_elements.py index 3a29e610..a5ad232f 100644 --- a/langchain/tools/playwright/get_elements.py +++ b/langchain/tools/playwright/get_elements.py @@ -5,6 +5,10 @@ from typing import TYPE_CHECKING, List, Optional, Sequence, Type from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import aget_current_page, get_current_page @@ -73,7 +77,12 @@ class GetElementsTool(BaseBrowserTool): ) args_schema: Type[BaseModel] = GetElementsToolInput - def _run(self, selector: str, attributes: Sequence[str] = ["innerText"]) -> str: + def _run( + self, + selector: str, + attributes: Sequence[str] = ["innerText"], + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.sync_browser is None: raise ValueError(f"Synchronous browser not provided to {self.name}") @@ -83,7 +92,10 @@ class GetElementsTool(BaseBrowserTool): return json.dumps(results) async def _arun( - self, selector: str, attributes: Sequence[str] = ["innerText"] + self, + selector: str, + attributes: Sequence[str] = ["innerText"], + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" if self.async_browser is None: diff --git a/langchain/tools/playwright/navigate.py b/langchain/tools/playwright/navigate.py index ad596b03..f9af3fc8 100644 --- a/langchain/tools/playwright/navigate.py +++ b/langchain/tools/playwright/navigate.py @@ -1,9 +1,13 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from 
pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( aget_current_page, @@ -22,7 +26,11 @@ class NavigateTool(BaseBrowserTool): description: str = "Navigate a browser to the specified URL" args_schema: Type[BaseModel] = NavigateToolInput - def _run(self, url: str) -> str: + def _run( + self, + url: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.sync_browser is None: raise ValueError(f"Synchronous browser not provided to {self.name}") @@ -31,7 +39,11 @@ class NavigateTool(BaseBrowserTool): status = response.status if response else "unknown" return f"Navigating to {url} returned status code {status}" - async def _arun(self, url: str) -> str: + async def _arun( + self, + url: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.async_browser is None: raise ValueError(f"Asynchronous browser not provided to {self.name}") diff --git a/langchain/tools/playwright/navigate_back.py b/langchain/tools/playwright/navigate_back.py index 5b613a81..da4d3577 100644 --- a/langchain/tools/playwright/navigate_back.py +++ b/langchain/tools/playwright/navigate_back.py @@ -1,9 +1,13 @@ from __future__ import annotations -from typing import Type +from typing import Optional, Type from pydantic import BaseModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.playwright.base import BaseBrowserTool from langchain.tools.playwright.utils import ( aget_current_page, @@ -18,7 +22,7 @@ class NavigateBackTool(BaseBrowserTool): description: str = "Navigate back to the previous page in the browser history" args_schema: Type[BaseModel] = BaseModel - def _run(self) -> str: + def _run(self, run_manager: Optional[CallbackManagerForToolRun] = None) -> str: """Use the tool.""" if self.sync_browser is None: raise ValueError(f"Synchronous browser not provided to {self.name}") @@ -33,7 +37,10 @@ class NavigateBackTool(BaseBrowserTool): else: return "Unable to navigate back; no previous page in the history" - async def _arun(self) -> str: + async def _arun( + self, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" if self.async_browser is None: raise ValueError(f"Asynchronous browser not provided to {self.name}") diff --git a/langchain/tools/plugin.py b/langchain/tools/plugin.py index 8f5fadd8..3d38895b 100644 --- a/langchain/tools/plugin.py +++ b/langchain/tools/plugin.py @@ -1,12 +1,16 @@ from __future__ import annotations import json -from typing import Optional +from typing import Optional, Type import requests import yaml from pydantic import BaseModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool @@ -45,9 +49,16 @@ def marshal_spec(txt: str) -> dict: return yaml.safe_load(txt) +class AIPLuginToolSchema(BaseModel): + """AIPLuginToolSchema.""" + + tool_input: Optional[str] = "" + + class AIPluginTool(BaseTool): plugin: AIPlugin api_spec: str + args_schema: Type[AIPLuginToolSchema] = AIPLuginToolSchema @classmethod def from_plugin_url(cls, url: str) -> AIPluginTool: @@ -72,10 +83,18 @@ class AIPluginTool(BaseTool): api_spec=api_spec, ) - def _run(self, tool_input: str) -> str: + def _run( + self, + 
tool_input: Optional[str] = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.api_spec - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: Optional[str] = None, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return self.api_spec diff --git a/langchain/tools/powerbi/tool.py b/langchain/tools/powerbi/tool.py index 67efe423..633f99d3 100644 --- a/langchain/tools/powerbi/tool.py +++ b/langchain/tools/powerbi/tool.py @@ -3,6 +3,10 @@ from typing import Any, Dict, Optional from pydantic import Field, validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.chains.llm import LLMChain from langchain.tools.base import BaseTool from langchain.tools.powerbi.prompt import ( @@ -45,7 +49,11 @@ class QueryPowerBITool(BaseTool): self.session_cache[tool_input] = BAD_REQUEST_RESPONSE_ESCALATED return self.session_cache[tool_input] - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Execute the query, return the results or an error message.""" if cache := self._check_cache(tool_input): return cache @@ -67,7 +75,11 @@ class QueryPowerBITool(BaseTool): ) return self.session_cache[tool_input] - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Execute the query, return the results or an error message.""" if cache := self._check_cache(tool_input): return cache @@ -107,11 +119,19 @@ class InfoPowerBITool(BaseTool): arbitrary_types_allowed = True - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Get the schema for tables in a comma-separated list.""" return self.powerbi.get_table_info(tool_input.split(", ")) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return await self.powerbi.aget_table_info(tool_input.split(", ")) @@ -127,11 +147,21 @@ class ListPowerBITool(BaseTool): arbitrary_types_allowed = True - def _run(self, *args: Any, **kwargs: Any) -> str: + def _run( + self, + *args: Any, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Get the names of the tables.""" return ", ".join(self.powerbi.get_table_names()) - async def _arun(self, *args: Any, **kwargs: Any) -> str: + async def _arun( + self, + *args: Any, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> str: """Get the names of the tables.""" return ", ".join(self.powerbi.get_table_names()) @@ -171,7 +201,11 @@ class InputToQueryTool(BaseTool): ) return llm_chain - def _run(self, tool_input: str) -> str: + def _run( + self, + tool_input: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the LLM to check the query.""" return self.llm_chain.predict( tool_input=tool_input, @@ -180,7 +214,11 @@ class InputToQueryTool(BaseTool): examples=self.examples, ) - async def _arun(self, tool_input: str) -> str: + async def _arun( + self, + tool_input: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return await self.llm_chain.apredict( 
tool_input=tool_input, tables=self.powerbi.get_table_names(), diff --git a/langchain/tools/python/tool.py b/langchain/tools/python/tool.py index 607fec22..2e67d670 100644 --- a/langchain/tools/python/tool.py +++ b/langchain/tools/python/tool.py @@ -3,10 +3,14 @@ import ast import sys from io import StringIO -from typing import Dict, Optional +from typing import Any, Dict, Optional from pydantic import Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities import PythonREPL @@ -28,13 +32,23 @@ class PythonREPLTool(BaseTool): python_repl: PythonREPL = Field(default_factory=_get_default_python_repl) sanitize_input: bool = True - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool.""" if self.sanitize_input: query = query.strip().strip("```") return self.python_repl.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + **kwargs: Any, + ) -> Any: """Use the tool asynchronously.""" raise NotImplementedError("PythonReplTool does not support async") @@ -64,7 +78,11 @@ class PythonAstREPLTool(BaseTool): ) return values - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" try: if self.sanitize_input: @@ -91,6 +109,10 @@ class PythonAstREPLTool(BaseTool): except Exception as e: return "{}: {}".format(type(e).__name__, str(e)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("PythonReplTool does not support async") diff --git a/langchain/tools/requests/tool.py b/langchain/tools/requests/tool.py index 1bfc8bc7..64b25303 100644 --- a/langchain/tools/requests/tool.py +++ b/langchain/tools/requests/tool.py @@ -1,9 +1,13 @@ # flake8: noqa """Tools for making requests to an API endpoint.""" import json -from typing import Any, Dict +from typing import Any, Dict, Optional from pydantic import BaseModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.requests import TextRequestsWrapper from langchain.tools.base import BaseTool @@ -31,11 +35,17 @@ class RequestsGetTool(BaseRequestsTool, BaseTool): name = "requests_get" description = "A portal to the internet. Use this when you need to get specific content from a website. Input should be a url (i.e. https://www.google.com). The output will be the text response of the GET request." - def _run(self, url: str) -> str: + def _run( + self, url: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" return self.requests_wrapper.get(_clean_url(url)) - async def _arun(self, url: str) -> str: + async def _arun( + self, + url: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" return await self.requests_wrapper.aget(_clean_url(url)) @@ -52,7 +62,9 @@ class RequestsPostTool(BaseRequestsTool, BaseTool): The output will be the text response of the POST request. 
""" - def _run(self, text: str) -> str: + def _run( + self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" try: data = _parse_input(text) @@ -60,7 +72,11 @@ class RequestsPostTool(BaseRequestsTool, BaseTool): except Exception as e: return repr(e) - async def _arun(self, text: str) -> str: + async def _arun( + self, + text: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" try: data = _parse_input(text) @@ -83,7 +99,9 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool): The output will be the text response of the PATCH request. """ - def _run(self, text: str) -> str: + def _run( + self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" try: data = _parse_input(text) @@ -91,7 +109,11 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool): except Exception as e: return repr(e) - async def _arun(self, text: str) -> str: + async def _arun( + self, + text: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" try: data = _parse_input(text) @@ -114,7 +136,9 @@ class RequestsPutTool(BaseRequestsTool, BaseTool): The output will be the text response of the PUT request. """ - def _run(self, text: str) -> str: + def _run( + self, text: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Run the tool.""" try: data = _parse_input(text) @@ -122,7 +146,11 @@ class RequestsPutTool(BaseRequestsTool, BaseTool): except Exception as e: return repr(e) - async def _arun(self, text: str) -> str: + async def _arun( + self, + text: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" try: data = _parse_input(text) @@ -139,10 +167,18 @@ class RequestsDeleteTool(BaseRequestsTool, BaseTool): name = "requests_delete" description = "A portal to the internet. Use this when you need to make a DELETE request to a URL. Input should be a specific url, and the output will be the text response of the DELETE request." 
- def _run(self, url: str) -> str: + def _run( + self, + url: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Run the tool.""" return self.requests_wrapper.delete(_clean_url(url)) - async def _arun(self, url: str) -> str: + async def _arun( + self, + url: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run the tool asynchronously.""" return await self.requests_wrapper.adelete(_clean_url(url)) diff --git a/langchain/tools/scenexplain/tool.py b/langchain/tools/scenexplain/tool.py index d8e5394c..7ac3a72e 100644 --- a/langchain/tools/scenexplain/tool.py +++ b/langchain/tools/scenexplain/tool.py @@ -1,11 +1,22 @@ """Tool for the SceneXplain API.""" +from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.scenexplain import SceneXplainAPIWrapper +class SceneXplainInput(BaseModel): + """Input for SceneXplain.""" + + query: str = Field(..., description="The link to the image to explain") + + class SceneXplainTool(BaseTool): """Tool that adds the capability to explain images.""" @@ -17,10 +28,14 @@ class SceneXplainTool(BaseTool): ) api_wrapper: SceneXplainAPIWrapper = Field(default_factory=SceneXplainAPIWrapper) - def _run(self, query: str) -> str: + def _run( + self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Use the tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("SceneXplainTool does not support async") diff --git a/langchain/tools/searx_search/tool.py b/langchain/tools/searx_search/tool.py index a91f7e27..e3ea04b5 100644 --- a/langchain/tools/searx_search/tool.py +++ b/langchain/tools/searx_search/tool.py @@ -1,6 +1,12 @@ """Tool for the SearxNG search API.""" +from typing import Optional + from pydantic import Extra +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.searx_search import SearxSearchWrapper @@ -16,11 +22,19 @@ class SearxSearchRun(BaseTool): ) wrapper: SearxSearchWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return self.wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return await self.wrapper.arun(query) @@ -42,10 +56,18 @@ class SearxSearchResults(BaseTool): extra = Extra.allow - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" return str(self.wrapper.results(query, self.num_results)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" return (await self.wrapper.aresults(query, self.num_results)).__str__() diff --git a/langchain/tools/shell/tool.py b/langchain/tools/shell/tool.py index 
8f9ecaef..42e19038 100644 --- a/langchain/tools/shell/tool.py +++ b/langchain/tools/shell/tool.py @@ -1,10 +1,14 @@ import asyncio import platform import warnings -from typing import List, Type +from typing import List, Optional, Type from pydantic import BaseModel, Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.bash import BashProcess @@ -60,11 +64,19 @@ class ShellTool(BaseTool): args_schema: Type[BaseModel] = ShellInput """Schema for input arguments.""" - def _run(self, commands: List[str]) -> str: + def _run( + self, + commands: List[str], + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Run commands and return final output.""" return self.process.run(commands) - async def _arun(self, commands: List[str]) -> str: + async def _arun( + self, + commands: List[str], + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Run commands asynchronously and return final output.""" return await asyncio.get_event_loop().run_in_executor( None, self.process.run, commands diff --git a/langchain/tools/sql_database/tool.py b/langchain/tools/sql_database/tool.py index d9d6cf63..2e677c6c 100644 --- a/langchain/tools/sql_database/tool.py +++ b/langchain/tools/sql_database/tool.py @@ -1,12 +1,17 @@ # flake8: noqa """Tools for interacting with a SQL database.""" -from pydantic import BaseModel, Extra, Field, validator, root_validator -from typing import Any, Dict +from typing import Any, Dict, Optional +from pydantic import BaseModel, Extra, Field, root_validator + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.chains.llm import LLMChain from langchain.prompts import PromptTemplate from langchain.sql_database import SQLDatabase -from langchain.schema import BaseLanguageModel from langchain.tools.base import BaseTool from langchain.tools.sql_database.prompt import QUERY_CHECKER @@ -35,11 +40,19 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool): If an error is returned, rewrite the query, check the query, and try again. 
""" - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Execute the query, return the results or an error message.""" return self.db.run_no_throw(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("QuerySqlDbTool does not support async") @@ -54,11 +67,19 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): Example Input: "table1, table2, table3" """ - def _run(self, table_names: str) -> str: + def _run( + self, + table_names: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Get the schema for tables in a comma-separated list.""" return self.db.get_table_info_no_throw(table_names.split(", ")) - async def _arun(self, table_name: str) -> str: + async def _arun( + self, + table_name: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("SchemaSqlDbTool does not support async") @@ -68,11 +89,19 @@ class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool): name = "list_tables_sql_db" description = "Input is an empty string, output is a comma separated list of tables in the database." - def _run(self, tool_input: str = "") -> str: + def _run( + self, + tool_input: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Get the schema for a specific table.""" return ", ".join(self.db.get_usable_table_names()) - async def _arun(self, tool_input: str = "") -> str: + async def _arun( + self, + tool_input: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: raise NotImplementedError("ListTablesSqlDbTool does not support async") @@ -106,9 +135,17 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool): return values - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the LLM to check the query.""" return self.llm_chain.predict(query=query, dialect=self.db.dialect) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: return await self.llm_chain.apredict(query=query, dialect=self.db.dialect) diff --git a/langchain/tools/vectorstore/tool.py b/langchain/tools/vectorstore/tool.py index 1dd18fd2..983224b4 100644 --- a/langchain/tools/vectorstore/tool.py +++ b/langchain/tools/vectorstore/tool.py @@ -1,10 +1,14 @@ """Tools for interacting with vectorstores.""" import json -from typing import Any, Dict +from typing import Any, Dict, Optional from pydantic import BaseModel, Field +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain from langchain.llms.base import BaseLLM from langchain.llms.openai import OpenAI @@ -42,14 +46,22 @@ class VectorStoreQATool(BaseVectorStoreTool, BaseTool): ) return template.format(name=name, description=description) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" chain = RetrievalQA.from_chain_type( self.llm, retriever=self.vectorstore.as_retriever() ) return chain.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + 
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("VectorStoreQATool does not support async") @@ -70,13 +82,21 @@ class VectorStoreQAWithSourcesTool(BaseVectorStoreTool, BaseTool): ) return template.format(name=name, description=description) - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the tool.""" chain = RetrievalQAWithSourcesChain.from_chain_type( self.llm, retriever=self.vectorstore.as_retriever() ) return json.dumps(chain({chain.question_key: query}, return_only_outputs=True)) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the tool asynchronously.""" raise NotImplementedError("VectorStoreQAWithSourcesTool does not support async") diff --git a/langchain/tools/wikipedia/tool.py b/langchain/tools/wikipedia/tool.py index 5bede75b..af398d7f 100644 --- a/langchain/tools/wikipedia/tool.py +++ b/langchain/tools/wikipedia/tool.py @@ -1,5 +1,11 @@ """Tool for the Wikipedia API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.wikipedia import WikipediaAPIWrapper @@ -16,10 +22,18 @@ class WikipediaQueryRun(BaseTool): ) api_wrapper: WikipediaAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Wikipedia tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Wikipedia tool asynchronously.""" raise NotImplementedError("WikipediaQueryRun does not support async") diff --git a/langchain/tools/wolfram_alpha/tool.py b/langchain/tools/wolfram_alpha/tool.py index ecac7b8f..a243d22f 100644 --- a/langchain/tools/wolfram_alpha/tool.py +++ b/langchain/tools/wolfram_alpha/tool.py @@ -1,5 +1,11 @@ """Tool for the Wolfram Alpha API.""" +from typing import Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import BaseTool from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper @@ -16,10 +22,18 @@ class WolframAlphaQueryRun(BaseTool): ) api_wrapper: WolframAlphaAPIWrapper - def _run(self, query: str) -> str: + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the WolframAlpha tool.""" return self.api_wrapper.run(query) - async def _arun(self, query: str) -> str: + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the WolframAlpha tool asynchronously.""" raise NotImplementedError("WolframAlphaQueryRun does not support async") diff --git a/langchain/tools/zapier/tool.py b/langchain/tools/zapier/tool.py index f6ec020e..f68a3562 100644 --- a/langchain/tools/zapier/tool.py +++ b/langchain/tools/zapier/tool.py @@ -81,6 +81,10 @@ from typing import Any, Dict, Optional from pydantic import Field, root_validator +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) from langchain.tools.base import 
BaseTool from langchain.tools.zapier.prompt import BASE_ZAPIER_TOOL_PROMPT from langchain.utilities.zapier import ZapierNLAWrapper @@ -119,11 +123,17 @@ class ZapierNLARunAction(BaseTool): ) return values - def _run(self, instructions: str) -> str: + def _run( + self, instructions: str, run_manager: Optional[CallbackManagerForToolRun] = None + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" return self.api_wrapper.run_as_str(self.action_id, instructions, self.params) - async def _arun(self, _: str) -> str: + async def _arun( + self, + _: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" raise NotImplementedError("ZapierNLAListActions does not support async") @@ -148,11 +158,19 @@ class ZapierNLAListActions(BaseTool): ) api_wrapper: ZapierNLAWrapper = Field(default_factory=ZapierNLAWrapper) - def _run(self, _: str) -> str: + def _run( + self, + _: str = "", + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" return self.api_wrapper.list_as_str() - async def _arun(self, _: str) -> str: + async def _arun( + self, + _: str = "", + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> str: """Use the Zapier NLA tool to return a list of all exposed user actions.""" raise NotImplementedError("ZapierNLAListActions does not support async") diff --git a/langchain/utilities/serpapi.py b/langchain/utilities/serpapi.py index 0710c2f3..65042264 100644 --- a/langchain/utilities/serpapi.py +++ b/langchain/utilities/serpapi.py @@ -76,11 +76,11 @@ class SerpAPIWrapper(BaseModel): ) return values - async def arun(self, query: str) -> str: + async def arun(self, query: str, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" return self._process_response(await self.aresults(query)) - def run(self, query: str) -> str: + def run(self, query: str, **kwargs: Any) -> str: """Run query through SerpAPI and parse result.""" return self._process_response(self.results(query)) diff --git a/tests/integration_tests/callbacks/__init__.py b/tests/integration_tests/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py new file mode 100644 index 00000000..cd830ea2 --- /dev/null +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -0,0 +1,123 @@ +"""Integration tests for the langchain tracer module.""" +import asyncio +import os + +import pytest +from aiohttp import ClientSession + +from langchain.agents import AgentType, initialize_agent, load_tools +from langchain.callbacks import tracing_enabled +from langchain.llms import OpenAI + +questions = [ + ( + "Who won the US Open men's final in 2019? " + "What is his age raised to the 0.334 power?" + ), + ( + "Who is Olivia Wilde's boyfriend? " + "What is his current age raised to the 0.23 power?" + ), + ( + "Who won the most recent formula 1 grand prix? " + "What is their age raised to the 0.23 power?" + ), + ( + "Who won the US Open women's final in 2019? " + "What is her age raised to the 0.34 power?" + ), + ("Who is Beyonce's husband? 
" "What is his age raised to the 0.19 power?"), +] + + +def test_tracing_sequential() -> None: + os.environ["LANGCHAIN_TRACING"] = "true" + + for q in questions[:3]: + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent.run(q) + + +def test_tracing_session_env_var() -> None: + os.environ["LANGCHAIN_TRACING"] = "true" + os.environ["LANGCHAIN_SESSION"] = "my_session" + + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent.run(questions[0]) + if "LANGCHAIN_SESSION" in os.environ: + del os.environ["LANGCHAIN_SESSION"] + + +@pytest.mark.asyncio +async def test_tracing_concurrent() -> None: + os.environ["LANGCHAIN_TRACING"] = "true" + aiosession = ClientSession() + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + tasks = [agent.arun(q) for q in questions[:3]] + await asyncio.gather(*tasks) + await aiosession.close() + + +@pytest.mark.asyncio +async def test_tracing_concurrent_bw_compat_environ() -> None: + os.environ["LANGCHAIN_HANDLER"] = "langchain" + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + aiosession = ClientSession() + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm, aiosession=aiosession) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + tasks = [agent.arun(q) for q in questions[:3]] + await asyncio.gather(*tasks) + await aiosession.close() + if "LANGCHAIN_HANDLER" in os.environ: + del os.environ["LANGCHAIN_HANDLER"] + + +def test_tracing_context_manager() -> None: + llm = OpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + with tracing_enabled() as session: + assert session + agent.run(questions[0]) # this should be traced + + agent.run(questions[0]) # this should not be traced + + +@pytest.mark.asyncio +async def test_tracing_context_manager_async() -> None: + llm = OpenAI(temperature=0) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + if "LANGCHAIN_TRACING" in os.environ: + del os.environ["LANGCHAIN_TRACING"] + + # start a background task + task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced + with tracing_enabled() as session: + assert session + tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced + await asyncio.gather(*tasks) + + await task diff --git a/tests/integration_tests/callbacks/test_openai_callback.py b/tests/integration_tests/callbacks/test_openai_callback.py new file mode 100644 index 00000000..9704cb56 --- /dev/null +++ b/tests/integration_tests/callbacks/test_openai_callback.py @@ -0,0 +1,55 @@ +"""Integration tests for the langchain tracer module.""" +import asyncio + +import pytest + +from langchain.agents import AgentType, initialize_agent, load_tools +from langchain.callbacks import get_openai_callback 
+from langchain.llms import OpenAI + + +@pytest.mark.asyncio +async def test_openai_callback() -> None: + llm = OpenAI(temperature=0) + with get_openai_callback() as cb: + llm("What is the square root of 4?") + + total_tokens = cb.total_tokens + assert total_tokens > 0 + + with get_openai_callback() as cb: + llm("What is the square root of 4?") + llm("What is the square root of 4?") + + assert cb.total_tokens == total_tokens * 2 + + with get_openai_callback() as cb: + await asyncio.gather( + *[llm.agenerate(["What is the square root of 4?"]) for _ in range(3)] + ) + + assert cb.total_tokens == total_tokens * 3 + + task = asyncio.create_task(llm.agenerate(["What is the square root of 4?"])) + with get_openai_callback() as cb: + await llm.agenerate(["What is the square root of 4?"]) + + await task + assert cb.total_tokens == total_tokens + + +def test_openai_callback_agent() -> None: + llm = OpenAI(temperature=0) + tools = load_tools(["serpapi", "llm-math"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + with get_openai_callback() as cb: + agent.run( + "Who is Olivia Wilde's boyfriend? " + "What is his current age raised to the 0.23 power?" + ) + print(f"Total Tokens: {cb.total_tokens}") + print(f"Prompt Tokens: {cb.prompt_tokens}") + print(f"Completion Tokens: {cb.completion_tokens}") + print(f"Total Cost (USD): ${cb.total_cost}") diff --git a/tests/integration_tests/chains/test_pal.py b/tests/integration_tests/chains/test_pal.py index 9bbf6f8d..cb03d80c 100644 --- a/tests/integration_tests/chains/test_pal.py +++ b/tests/integration_tests/chains/test_pal.py @@ -6,7 +6,7 @@ from langchain.chains.pal.base import PALChain def test_math_prompt() -> None: """Test math prompt.""" - llm = OpenAI(model_name="code-davinci-002", temperature=0, max_tokens=512) + llm = OpenAI(temperature=0, max_tokens=512) pal_chain = PALChain.from_math_prompt(llm) question = ( "Jan has three times the number of pets as Marcia. " @@ -19,7 +19,7 @@ def test_math_prompt() -> None: def test_colored_object_prompt() -> None: """Test colored object prompt.""" - llm = OpenAI(model_name="code-davinci-002", temperature=0, max_tokens=512) + llm = OpenAI(temperature=0, max_tokens=512) pal_chain = PALChain.from_colored_object_prompt(llm) question = ( "On the desk, you see two blue booklets, " diff --git a/tests/integration_tests/chains/test_sql_database.py b/tests/integration_tests/chains/test_sql_database.py index 3518866c..f19ec025 100644 --- a/tests/integration_tests/chains/test_sql_database.py +++ b/tests/integration_tests/chains/test_sql_database.py @@ -27,7 +27,7 @@ def test_sql_database_run() -> None: with engine.connect() as conn: conn.execute(stmt) db = SQLDatabase(engine) - db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db) + db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db) output = db_chain.run("What company does Harrison work at?") expected_output = " Harrison works at Foo." assert output == expected_output @@ -41,7 +41,7 @@ def test_sql_database_run_update() -> None: with engine.connect() as conn: conn.execute(stmt) db = SQLDatabase(engine) - db_chain = SQLDatabaseChain(llm=OpenAI(temperature=0), database=db) + db_chain = SQLDatabaseChain.from_llm(OpenAI(temperature=0), db) output = db_chain.run("Update Harrison's workplace to Bar") expected_output = " Harrison's workplace has been updated to Bar." 
assert output == expected_output @@ -59,9 +59,7 @@ def test_sql_database_sequential_chain_run() -> None: with engine.connect() as conn: conn.execute(stmt) db = SQLDatabase(engine) - db_chain = SQLDatabaseSequentialChain.from_llm( - llm=OpenAI(temperature=0), database=db - ) + db_chain = SQLDatabaseSequentialChain.from_llm(OpenAI(temperature=0), db) output = db_chain.run("What company does Harrison work at?") expected_output = " Harrison works at Foo." assert output == expected_output @@ -77,7 +75,7 @@ def test_sql_database_sequential_chain_intermediate_steps() -> None: conn.execute(stmt) db = SQLDatabase(engine) db_chain = SQLDatabaseSequentialChain.from_llm( - llm=OpenAI(temperature=0), database=db, return_intermediate_steps=True + OpenAI(temperature=0), db, return_intermediate_steps=True ) output = db_chain("What company does Harrison work at?") expected_output = " Harrison works at Foo." diff --git a/tests/integration_tests/chat_models/test_anthropic.py b/tests/integration_tests/chat_models/test_anthropic.py index 60fe58f3..a7186c8b 100644 --- a/tests/integration_tests/chat_models/test_anthropic.py +++ b/tests/integration_tests/chat_models/test_anthropic.py @@ -3,7 +3,7 @@ from typing import List import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.chat_models.anthropic import ChatAnthropic from langchain.schema import ( AIMessage, diff --git a/tests/integration_tests/chat_models/test_openai.py b/tests/integration_tests/chat_models/test_openai.py index 06394ebc..432c2c88 100644 --- a/tests/integration_tests/chat_models/test_openai.py +++ b/tests/integration_tests/chat_models/test_openai.py @@ -3,7 +3,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.chat_models.openai import ChatOpenAI from langchain.schema import ( BaseMessage, diff --git a/tests/integration_tests/chat_models/test_promptlayer_openai.py b/tests/integration_tests/chat_models/test_promptlayer_openai.py index c9962f75..ab68a085 100644 --- a/tests/integration_tests/chat_models/test_promptlayer_openai.py +++ b/tests/integration_tests/chat_models/test_promptlayer_openai.py @@ -2,7 +2,7 @@ import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.schema import ( BaseMessage, diff --git a/tests/integration_tests/llms/test_anthropic.py b/tests/integration_tests/llms/test_anthropic.py index 2e81f297..1d0a9475 100644 --- a/tests/integration_tests/llms/test_anthropic.py +++ b/tests/integration_tests/llms/test_anthropic.py @@ -3,7 +3,7 @@ from typing import Generator import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.llms.anthropic import Anthropic from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler diff --git a/tests/integration_tests/llms/test_llamacpp.py b/tests/integration_tests/llms/test_llamacpp.py index 7ea2881f..e1a28594 100644 --- a/tests/integration_tests/llms/test_llamacpp.py +++ b/tests/integration_tests/llms/test_llamacpp.py @@ -5,7 +5,6 @@ from typing import Generator from urllib.request import urlretrieve from langchain.llms import LlamaCpp -from langchain.callbacks.base import CallbackManager from 
tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -61,10 +60,9 @@ def test_llamacpp_streaming_callback() -> None: OFF_BY_ONE = 1 # There may be an off by one error in the upstream code! callback_handler = FakeCallbackHandler() - callback_manager = CallbackManager([callback_handler]) llm = LlamaCpp( model_path=get_model(), - callback_manager=callback_manager, + callbacks=[callback_handler], verbose=True, max_tokens=MAX_TOKENS, ) diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index 9db120a5..e10a9c0b 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -5,7 +5,7 @@ from typing import Generator import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManager from langchain.llms.loading import load_llm from langchain.llms.openai import OpenAI, OpenAIChat from langchain.schema import LLMResult diff --git a/tests/integration_tests/test_schema.py b/tests/integration_tests/test_schema.py index 472c27cc..18bb57ab 100644 --- a/tests/integration_tests/test_schema.py +++ b/tests/integration_tests/test_schema.py @@ -1,6 +1,6 @@ """Test formatting functionality.""" -from langchain.schema import _get_num_tokens_default_method +from langchain.base_language import _get_num_tokens_default_method class TestTokenCountingWithGPT2Tokenizer: diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index f30b3f05..3a03f03f 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -4,7 +4,7 @@ from typing import Any, List, Mapping, Optional from langchain.agents import AgentExecutor, AgentType, initialize_agent from langchain.agents.tools import Tool -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -15,7 +15,12 @@ class FakeListLLM(LLM): responses: List[str] i: int = -1 - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Increment counter, and then return response in that index.""" self.i += 1 print(f"=== Mock Response #{self.i} ===") @@ -82,135 +87,57 @@ def test_agent_stopped_early() -> None: assert output == "Agent stopped due to iteration limit or time limit." 
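# Illustrative sketch (not part of the diff): the two ways of attaching handlers
# that the rewritten agent test below exercises. Handlers passed to a component's
# constructor fire only for that component, while handlers passed to run() are
# inherited by every child run (chain, LLM, and tool events). OpenAI and
# StdOutCallbackHandler are stand-ins; any LLM or handler follows the same pattern.
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import StdOutCallbackHandler
from langchain.llms import OpenAI

llm_handler = StdOutCallbackHandler()  # receives only this LLM's callbacks
run_handler = StdOutCallbackHandler()  # receives callbacks for the whole run

llm = OpenAI(temperature=0, callbacks=[llm_handler])
tools = load_tools(["llm-math"], llm=llm)
agent = initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION)
agent.run("What is 2 raised to the 0.43 power?", callbacks=[run_handler])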
-def test_agent_with_callbacks_global() -> None: +def test_agent_with_callbacks() -> None: """Test react chain with callbacks by setting verbose globally.""" - import langchain + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() - langchain.verbose = True - handler = FakeCallbackHandler() - manager = CallbackManager(handlers=[handler]) tool = "Search" responses = [ f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", "Oh well\nFinal Answer: curses foiled again", ] - fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True) + # Only fake LLM gets callbacks for handler2 + fake_llm = FakeListLLM(responses=responses, callbacks=[handler2]) tools = [ Tool( name="Search", func=lambda x: x, description="Useful for searching", - callback_manager=manager, ), ] agent = initialize_agent( tools, fake_llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - verbose=True, - callback_manager=manager, ) - output = agent.run("when was langchain made") + output = agent.run("when was langchain made", callbacks=[handler1]) assert output == "curses foiled again" # 1 top level chain run runs, 2 LLMChain runs, 2 LLM runs, 1 tool run - assert handler.chain_starts == handler.chain_ends == 3 - assert handler.llm_starts == handler.llm_ends == 2 - assert handler.tool_starts == 2 - assert handler.tool_ends == 1 + assert handler1.chain_starts == handler1.chain_ends == 3 + assert handler1.llm_starts == handler1.llm_ends == 2 + assert handler1.tool_starts == 1 + assert handler1.tool_ends == 1 # 1 extra agent action - assert handler.starts == 7 + assert handler1.starts == 7 # 1 extra agent end - assert handler.ends == 7 - assert handler.errors == 0 + assert handler1.ends == 7 + assert handler1.errors == 0 # during LLMChain - assert handler.text == 2 + assert handler1.text == 2 - -def test_agent_with_callbacks_local() -> None: - """Test react chain with callbacks by setting verbose locally.""" - import langchain - - langchain.verbose = False - handler = FakeCallbackHandler() - manager = CallbackManager(handlers=[handler]) - tool = "Search" - responses = [ - f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", - "Oh well\nFinal Answer: curses foiled again", - ] - fake_llm = FakeListLLM(responses=responses, callback_manager=manager, verbose=True) - tools = [ - Tool( - name="Search", - func=lambda x: x, - description="Useful for searching", - callback_manager=manager, - ), - ] - agent = initialize_agent( - tools, - fake_llm, - agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - verbose=True, - callback_manager=manager, + assert handler2.llm_starts == 2 + assert handler2.llm_ends == 2 + assert ( + handler2.chain_starts + == handler2.tool_starts + == handler2.tool_ends + == handler2.chain_ends + == 0 ) - agent.agent.llm_chain.verbose = True # type: ignore - - output = agent.run("when was langchain made") - assert output == "curses foiled again" - - # 1 top level chain run, 2 LLMChain starts, 2 LLM runs, 1 tool run - assert handler.chain_starts == handler.chain_ends == 3 - assert handler.llm_starts == handler.llm_ends == 2 - assert handler.tool_starts == 2 - assert handler.tool_ends == 1 - # 1 extra agent action - assert handler.starts == 7 - # 1 extra agent end - assert handler.ends == 7 - assert handler.errors == 0 - # during LLMChain - assert handler.text == 2 - - -def test_agent_with_callbacks_not_verbose() -> None: - """Test react chain with callbacks but not verbose.""" - import langchain - - langchain.verbose = False - handler = FakeCallbackHandler() - manager = 
CallbackManager(handlers=[handler]) - tool = "Search" - responses = [ - f"FooBarBaz\nAction: {tool}\nAction Input: misalignment", - "Oh well\nFinal Answer: curses foiled again", - ] - fake_llm = FakeListLLM(responses=responses, callback_manager=manager) - tools = [ - Tool( - name="Search", - func=lambda x: x, - description="Useful for searching", - ), - ] - agent = initialize_agent( - tools, - fake_llm, - agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, - callback_manager=manager, - ) - - output = agent.run("when was langchain made") - assert output == "curses foiled again" - - # 1 top level chain run, 2 LLMChain runs, 2 LLM runs, 1 tool run - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 - def test_agent_tool_return_direct() -> None: """Test agent using tools that return directly.""" diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index 2689ea36..8f2a3ff2 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -4,6 +4,7 @@ from typing import Any, List, Mapping, Optional, Union from langchain.agents.react.base import ReActChain, ReActDocstoreAgent from langchain.agents.tools import Tool +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.docstore.base import Docstore from langchain.docstore.document import Document from langchain.llms.base import LLM @@ -32,7 +33,12 @@ class FakeListLLM(LLM): """Return type of llm.""" return "fake_list" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Increment counter, and then return response in that index.""" self.i += 1 return self.responses[self.i] diff --git a/tests/unit_tests/agents/test_tools.py b/tests/unit_tests/agents/test_tools.py index 965d44db..055720d8 100644 --- a/tests/unit_tests/agents/test_tools.py +++ b/tests/unit_tests/agents/test_tools.py @@ -171,7 +171,7 @@ def test_decorated_function_schema_equivalent() -> None: def test_structured_args_decorator_no_infer_schema() -> None: """Test functionality with structured arguments parsed as a decorator.""" - @tool + @tool(infer_schema=False) def structured_tool_input( arg1: int, arg2: Union[float, datetime], opt_arg: Optional[dict] = None ) -> str: @@ -182,7 +182,8 @@ def test_structured_args_decorator_no_infer_schema() -> None: assert structured_tool_input.name == "structured_tool_input" args = {"arg1": 1, "arg2": 0.001, "opt_arg": {"foo": "bar"}} expected_result = "1, 0.001, {'foo': 'bar'}" - assert structured_tool_input.run(args) == expected_result + with pytest.raises(ValueError): + assert structured_tool_input.run(args) == expected_result def test_structured_single_str_decorator_no_infer_schema() -> None: diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index 921596e7..ef2b8171 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -1,10 +1,9 @@ """A fake callback handler for testing purposes.""" -from typing import Any, Dict, List, Union +from typing import Any from pydantic import BaseModel from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, LLMResult class BaseFakeCallbackHandler(BaseModel): @@ -17,12 +16,72 @@ class BaseFakeCallbackHandler(BaseModel): 
ignore_llm_: bool = False ignore_chain_: bool = False ignore_agent_: bool = False - always_verbose_: bool = False - @property - def always_verbose(self) -> bool: - """Whether to call verbose callbacks even if verbose is False.""" - return self.always_verbose_ + # add finer-grained counters for easier debugging of failing tests + chain_starts: int = 0 + chain_ends: int = 0 + llm_starts: int = 0 + llm_ends: int = 0 + llm_streams: int = 0 + tool_starts: int = 0 + tool_ends: int = 0 + agent_actions: int = 0 + agent_ends: int = 0 + + +class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): + """Base fake callback handler mixin for testing.""" + + def on_llm_start_common(self) -> None: + self.llm_starts += 1 + self.starts += 1 + + def on_llm_end_common(self) -> None: + self.llm_ends += 1 + self.ends += 1 + + def on_llm_error_common(self) -> None: + self.errors += 1 + + def on_llm_new_token_common(self) -> None: + self.llm_streams += 1 + + def on_chain_start_common(self) -> None: + self.chain_starts += 1 + self.starts += 1 + + def on_chain_end_common(self) -> None: + self.chain_ends += 1 + self.ends += 1 + + def on_chain_error_common(self) -> None: + self.errors += 1 + + def on_tool_start_common(self) -> None: + self.tool_starts += 1 + self.starts += 1 + + def on_tool_end_common(self) -> None: + self.tool_ends += 1 + self.ends += 1 + + def on_tool_error_common(self) -> None: + self.errors += 1 + + def on_agent_action_common(self) -> None: + self.agent_actions += 1 + self.starts += 1 + + def on_agent_finish_common(self) -> None: + self.agent_ends += 1 + self.ends += 1 + + def on_text_common(self) -> None: + self.text += 1 + + +class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): + """Fake callback handler for testing.""" @property def ignore_llm(self) -> bool: @@ -39,164 +98,209 @@ class BaseFakeCallbackHandler(BaseModel): """Whether to ignore agent callbacks.""" return self.ignore_agent_ - # add finer-grained counters for easier debugging of failing tests - chain_starts: int = 0 - chain_ends: int = 0 - llm_starts: int = 0 - llm_ends: int = 0 - llm_streams: int = 0 - tool_starts: int = 0 - tool_ends: int = 0 - agent_ends: int = 0 - - -class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler): - """Fake callback handler for testing.""" - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running.""" - self.llm_starts += 1 - self.starts += 1 + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_start_common() - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run when LLM generates a new token.""" - self.llm_streams += 1 + def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_new_token_common() - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - self.llm_ends += 1 - self.ends += 1 + def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_end_common() def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" - self.errors += 1 + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_llm_error_common() def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any - ) -> None: - """Run when chain starts running.""" - self.chain_starts += 1 - self.starts += 1 + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_start_common() - def 
on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - self.chain_ends += 1 - self.ends += 1 + def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_end_common() def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when chain errors.""" - self.errors += 1 + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_chain_error_common() def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any - ) -> None: - """Run when tool starts running.""" - self.tool_starts += 1 - self.starts += 1 + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_start_common() - def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - self.tool_ends += 1 - self.ends += 1 + def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_end_common() def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when tool errors.""" - self.errors += 1 + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_tool_error_common() - def on_text(self, text: str, **kwargs: Any) -> None: - """Run when agent is ending.""" - self.text += 1 + def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_action_common() - def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run when agent ends running.""" - self.agent_ends += 1 - self.ends += 1 + def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_agent_finish_common() - def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: - """Run on agent action.""" - self.tool_starts += 1 - self.starts += 1 + def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + self.on_text_common() + + def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": + return self -class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler): +class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin): """Fake async callback handler for testing.""" + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return self.ignore_llm_ + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return self.ignore_chain_ + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return self.ignore_agent_ + async def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when LLM starts running.""" - self.llm_starts += 1 - self.starts += 1 + self.on_llm_start_common() - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run when LLM generates a new token.""" - self.llm_streams += 1 + async def on_llm_new_token( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_new_token_common() - async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - self.llm_ends += 1 - self.ends += 1 + async def on_llm_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_llm_end_common() async def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when LLM errors.""" - self.errors += 1 + self.on_llm_error_common() async def on_chain_start( - self, 
serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when chain starts running.""" - self.chain_starts += 1 - self.starts += 1 + self.on_chain_start_common() - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: - """Run when chain ends running.""" - self.chain_ends += 1 - self.ends += 1 + async def on_chain_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_chain_end_common() async def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when chain errors.""" - self.errors += 1 + self.on_chain_error_common() async def on_tool_start( - self, serialized: Dict[str, Any], input_str: str, **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when tool starts running.""" - self.tool_starts += 1 - self.starts += 1 + self.on_tool_start_common() - async def on_tool_end(self, output: str, **kwargs: Any) -> None: - """Run when tool ends running.""" - self.tool_ends += 1 - self.ends += 1 + async def on_tool_end( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_tool_end_common() async def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + self, + *args: Any, + **kwargs: Any, ) -> None: - """Run when tool errors.""" - self.errors += 1 + self.on_tool_error_common() - async def on_text(self, text: str, **kwargs: Any) -> None: - """Run when agent is ending.""" - self.text += 1 + async def on_agent_action( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_action_common() - async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: - """Run when agent ends running.""" - self.agent_ends += 1 - self.ends += 1 + async def on_agent_finish( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_agent_finish_common() - async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None: - """Run on agent action.""" - self.tool_starts += 1 - self.starts += 1 + async def on_text( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self.on_text_common() + + def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": + return self diff --git a/tests/unit_tests/callbacks/test_callback_manager.py b/tests/unit_tests/callbacks/test_callback_manager.py index 0f61fdd3..6a215985 100644 --- a/tests/unit_tests/callbacks/test_callback_manager.py +++ b/tests/unit_tests/callbacks/test_callback_manager.py @@ -1,15 +1,12 @@ """Test CallbackManager.""" -from typing import Tuple +from typing import List, Tuple import pytest -from langchain.callbacks.base import ( - AsyncCallbackManager, - BaseCallbackManager, - CallbackManager, -) -from langchain.callbacks.shared import SharedCallbackManager -from langchain.schema import AgentFinish, LLMResult +from langchain.callbacks.base import BaseCallbackHandler +from langchain.callbacks.manager import AsyncCallbackManager, CallbackManager +from langchain.callbacks.stdout import StdOutCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult from tests.unit_tests.callbacks.fake_callback_handler import ( BaseFakeCallbackHandler, FakeAsyncCallbackHandler, @@ -18,19 +15,26 @@ from tests.unit_tests.callbacks.fake_callback_handler import ( def _test_callback_manager( - manager: BaseCallbackManager, *handlers: BaseFakeCallbackHandler + manager: CallbackManager, *handlers: BaseFakeCallbackHandler ) -> None: """Test the CallbackManager.""" - 
manager.on_llm_start({}, []) - manager.on_llm_end(LLMResult(generations=[])) - manager.on_llm_error(Exception()) - manager.on_chain_start({"name": "foo"}, {}) - manager.on_chain_end({}) - manager.on_chain_error(Exception()) - manager.on_tool_start({}, "") - manager.on_tool_end("") - manager.on_tool_error(Exception()) - manager.on_agent_finish(AgentFinish(log="", return_values={})) + run_manager = manager.on_llm_start({}, []) + run_manager.on_llm_end(LLMResult(generations=[])) + run_manager.on_llm_error(Exception()) + run_manager.on_llm_new_token("foo") + run_manager.on_text("foo") + + run_manager_chain = manager.on_chain_start({"name": "foo"}, {}) + run_manager_chain.on_chain_end({}) + run_manager_chain.on_chain_error(Exception()) + run_manager_chain.on_agent_action(AgentAction(tool_input="foo", log="", tool="")) + run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={})) + run_manager_chain.on_text("foo") + + run_manager_tool = manager.on_tool_start({}, "") + run_manager_tool.on_tool_end("") + run_manager_tool.on_tool_error(Exception()) + run_manager_tool.on_text("foo") _check_num_calls(handlers) @@ -38,75 +42,62 @@ async def _test_callback_manager_async( manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler ) -> None: """Test the CallbackManager.""" - await manager.on_llm_start({}, []) - await manager.on_llm_end(LLMResult(generations=[])) - await manager.on_llm_error(Exception()) - await manager.on_chain_start({"name": "foo"}, {}) - await manager.on_chain_end({}) - await manager.on_chain_error(Exception()) - await manager.on_tool_start({}, "") - await manager.on_tool_end("") - await manager.on_tool_error(Exception()) - await manager.on_agent_finish(AgentFinish(log="", return_values={})) + run_manager = await manager.on_llm_start({}, []) + await run_manager.on_llm_end(LLMResult(generations=[])) + await run_manager.on_llm_error(Exception()) + await run_manager.on_llm_new_token("foo") + await run_manager.on_text("foo") + + run_manager_chain = await manager.on_chain_start({"name": "foo"}, {}) + await run_manager_chain.on_chain_end({}) + await run_manager_chain.on_chain_error(Exception()) + await run_manager_chain.on_agent_action( + AgentAction(tool_input="foo", log="", tool="") + ) + await run_manager_chain.on_agent_finish(AgentFinish(log="", return_values={})) + await run_manager_chain.on_text("foo") + + run_manager_tool = await manager.on_tool_start({}, "") + await run_manager_tool.on_tool_end("") + await run_manager_tool.on_tool_error(Exception()) + await run_manager_tool.on_text("foo") _check_num_calls(handlers) def _check_num_calls(handlers: Tuple[BaseFakeCallbackHandler, ...]) -> None: for handler in handlers: - if handler.always_verbose: - assert handler.starts == 3 - assert handler.ends == 4 - assert handler.errors == 3 - else: - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 - - -def _test_callback_manager_pass_in_verbose( - manager: BaseCallbackManager, *handlers: FakeCallbackHandler -) -> None: - """Test the CallbackManager.""" - manager.on_llm_start({}, [], verbose=True) - manager.on_llm_end(LLMResult(generations=[]), verbose=True) - manager.on_llm_error(Exception(), verbose=True) - manager.on_chain_start({"name": "foo"}, {}, verbose=True) - manager.on_chain_end({}, verbose=True) - manager.on_chain_error(Exception(), verbose=True) - manager.on_tool_start({}, "", verbose=True) - manager.on_tool_end("", verbose=True) - manager.on_tool_error(Exception(), verbose=True) - manager.on_agent_finish(AgentFinish(log="", 
return_values={}), verbose=True) - for handler in handlers: - assert handler.starts == 3 + assert handler.starts == 4 assert handler.ends == 4 assert handler.errors == 3 + assert handler.text == 3 + + assert handler.llm_starts == 1 + assert handler.llm_ends == 1 + assert handler.llm_streams == 1 + + assert handler.chain_starts == 1 + assert handler.chain_ends == 1 + + assert handler.tool_starts == 1 + assert handler.tool_ends == 1 def test_callback_manager() -> None: """Test the CallbackManager.""" - handler1 = FakeCallbackHandler(always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=False) + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() manager = CallbackManager([handler1, handler2]) _test_callback_manager(manager, handler1, handler2) -def test_callback_manager_pass_in_verbose() -> None: - """Test the CallbackManager.""" - handler1 = FakeCallbackHandler() - handler2 = FakeCallbackHandler() - manager = CallbackManager([handler1, handler2]) - _test_callback_manager_pass_in_verbose(manager, handler1, handler2) - - def test_ignore_llm() -> None: """Test ignore llm param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_llm_=True, always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler(ignore_llm_=True) + handler2 = FakeCallbackHandler() manager = CallbackManager(handlers=[handler1, handler2]) - manager.on_llm_start({}, [], verbose=True) - manager.on_llm_end(LLMResult(generations=[]), verbose=True) - manager.on_llm_error(Exception(), verbose=True) + run_manager = manager.on_llm_start({}, []) + run_manager.on_llm_end(LLMResult(generations=[])) + run_manager.on_llm_error(Exception()) assert handler1.starts == 0 assert handler1.ends == 0 assert handler1.errors == 0 @@ -117,12 +108,12 @@ def test_ignore_llm() -> None: def test_ignore_chain() -> None: """Test ignore chain param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_chain_=True, always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler(ignore_chain_=True) + handler2 = FakeCallbackHandler() manager = CallbackManager(handlers=[handler1, handler2]) - manager.on_chain_start({"name": "foo"}, {}, verbose=True) - manager.on_chain_end({}, verbose=True) - manager.on_chain_error(Exception(), verbose=True) + run_manager = manager.on_chain_start({"name": "foo"}, {}) + run_manager.on_chain_end({}) + run_manager.on_chain_error(Exception()) assert handler1.starts == 0 assert handler1.ends == 0 assert handler1.errors == 0 @@ -133,39 +124,24 @@ def test_ignore_chain() -> None: def test_ignore_agent() -> None: """Test ignore agent param for callback handlers.""" - handler1 = FakeCallbackHandler(ignore_agent_=True, always_verbose_=True) - handler2 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler(ignore_agent_=True) + handler2 = FakeCallbackHandler() manager = CallbackManager(handlers=[handler1, handler2]) - manager.on_tool_start({}, "", verbose=True) - manager.on_tool_end("", verbose=True) - manager.on_tool_error(Exception(), verbose=True) - manager.on_agent_finish(AgentFinish({}, ""), verbose=True) + run_manager = manager.on_tool_start({}, "") + run_manager.on_tool_end("") + run_manager.on_tool_error(Exception()) assert handler1.starts == 0 assert handler1.ends == 0 assert handler1.errors == 0 assert handler2.starts == 1 - assert handler2.ends == 2 + assert handler2.ends == 1 assert handler2.errors == 1 -def test_shared_callback_manager() -> None: - 
"""Test the SharedCallbackManager.""" - manager1 = SharedCallbackManager() - manager2 = SharedCallbackManager() - - assert manager1 is manager2 - - handler1 = FakeCallbackHandler(always_verbose_=True) - handler2 = FakeCallbackHandler() - manager1.add_handler(handler1) - manager2.add_handler(handler2) - _test_callback_manager(manager1, handler1, handler2) - - @pytest.mark.asyncio async def test_async_callback_manager() -> None: """Test the AsyncCallbackManager.""" - handler1 = FakeAsyncCallbackHandler(always_verbose_=True) + handler1 = FakeAsyncCallbackHandler() handler2 = FakeAsyncCallbackHandler() manager = AsyncCallbackManager([handler1, handler2]) await _test_callback_manager_async(manager, handler1, handler2) @@ -174,8 +150,95 @@ async def test_async_callback_manager() -> None: @pytest.mark.asyncio async def test_async_callback_manager_sync_handler() -> None: """Test the AsyncCallbackManager.""" - handler1 = FakeCallbackHandler(always_verbose_=True) + handler1 = FakeCallbackHandler() handler2 = FakeAsyncCallbackHandler() - handler3 = FakeAsyncCallbackHandler(always_verbose_=True) + handler3 = FakeAsyncCallbackHandler() manager = AsyncCallbackManager([handler1, handler2, handler3]) await _test_callback_manager_async(manager, handler1, handler2, handler3) + + +def test_callback_manager_inheritance() -> None: + handler1, handler2, handler3, handler4 = ( + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + ) + + callback_manager1 = CallbackManager([handler1, handler2]) + assert callback_manager1.handlers == [handler1, handler2] + assert callback_manager1.inheritable_handlers == [] + + callback_manager2 = CallbackManager([]) + assert callback_manager2.handlers == [] + assert callback_manager2.inheritable_handlers == [] + + callback_manager2.set_handlers([handler1, handler2]) + assert callback_manager2.handlers == [handler1, handler2] + assert callback_manager2.inheritable_handlers == [handler1, handler2] + + callback_manager2.set_handlers([handler3, handler4], inherit=False) + assert callback_manager2.handlers == [handler3, handler4] + assert callback_manager2.inheritable_handlers == [] + + callback_manager2.add_handler(handler1) + assert callback_manager2.handlers == [handler3, handler4, handler1] + assert callback_manager2.inheritable_handlers == [handler1] + + callback_manager2.add_handler(handler2, inherit=False) + assert callback_manager2.handlers == [handler3, handler4, handler1, handler2] + assert callback_manager2.inheritable_handlers == [handler1] + + run_manager = callback_manager2.on_chain_start({"name": "foo"}, {}) + child_manager = run_manager.get_child() + assert child_manager.handlers == [handler1] + assert child_manager.inheritable_handlers == [handler1] + + run_manager_tool = child_manager.on_tool_start({}, "") + assert run_manager_tool.handlers == [handler1] + assert run_manager_tool.inheritable_handlers == [handler1] + + child_manager2 = run_manager_tool.get_child() + assert child_manager2.handlers == [handler1] + assert child_manager2.inheritable_handlers == [handler1] + + +def test_callback_manager_configure() -> None: + """Test callback manager configuration.""" + handler1, handler2, handler3, handler4 = ( + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + FakeCallbackHandler(), + ) + + inheritable_callbacks: List[BaseCallbackHandler] = [handler1, handler2] + local_callbacks: List[BaseCallbackHandler] = [handler3, handler4] + configured_manager = CallbackManager.configure( + 
inheritable_callbacks=inheritable_callbacks, + local_callbacks=local_callbacks, + verbose=True, + ) + + assert len(configured_manager.handlers) == 5 + assert len(configured_manager.inheritable_handlers) == 2 + assert configured_manager.inheritable_handlers == inheritable_callbacks + assert configured_manager.handlers[:4] == inheritable_callbacks + local_callbacks + assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler) + assert isinstance(configured_manager, CallbackManager) + + async_local_callbacks = AsyncCallbackManager([handler3, handler4]) + async_configured_manager = AsyncCallbackManager.configure( + inheritable_callbacks=inheritable_callbacks, + local_callbacks=async_local_callbacks, + verbose=False, + ) + + assert len(async_configured_manager.handlers) == 4 + assert len(async_configured_manager.inheritable_handlers) == 2 + assert async_configured_manager.inheritable_handlers == inheritable_callbacks + assert async_configured_manager.handlers == inheritable_callbacks + [ + handler3, + handler4, + ] + assert isinstance(async_configured_manager, AsyncCallbackManager) diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index ab18d53e..66f0c387 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -1,9 +1,9 @@ """Test Tracer classes.""" from __future__ import annotations -import threading from datetime import datetime -from typing import List, Optional, Union +from typing import List, Union +from uuid import uuid4 import pytest from freezegun import freeze_time @@ -12,9 +12,7 @@ from langchain.callbacks.tracers.base import ( BaseTracer, ChainRun, LLMRun, - SharedTracer, ToolRun, - Tracer, TracerException, TracerSession, ) @@ -24,88 +22,6 @@ from langchain.schema import LLMResult TEST_SESSION_ID = 2023 -@freeze_time("2023-01-01") -def _get_compare_run() -> Union[LLMRun, ChainRun, ToolRun]: - return ChainRun( - id=None, - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=1, - serialized={}, - inputs={}, - outputs={}, - session_id=TEST_SESSION_ID, - child_runs=[ - ToolRun( - id=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=2, - serialized={}, - tool_input="test", - output="test", - action="{}", - session_id=TEST_SESSION_ID, - error=None, - child_runs=[ - LLMRun( - id=None, - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=3, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - ) - ], - ), - LLMRun( - id=None, - error=None, - start_time=datetime.utcnow(), - end_time=datetime.utcnow(), - extra={}, - execution_order=4, - serialized={}, - prompts=[], - response=LLMResult(generations=[[]]), - session_id=TEST_SESSION_ID, - ), - ], - ) - - -def _perform_nested_run(tracer: BaseTracer) -> None: - """Perform a nested run.""" - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_tool_start(serialized={}, input_str="test") - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_tool_end("test") - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_chain_end(outputs={}) - - -def _add_child_run( - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], -) -> None: - """Add child run to a 
chain run or tool run.""" - parent_run.child_runs.append(child_run) - - -def _generate_id() -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return None - - def load_session(session_name: str) -> TracerSession: """Load a tracing session.""" return TracerSession(id=1, name=session_name, start_time=datetime.utcnow()) @@ -121,7 +37,7 @@ def load_default_session() -> TracerSession: return TracerSession(id=1, name="default", start_time=datetime.utcnow()) -class FakeTracer(Tracer): +class FakeTracer(BaseTracer): """Fake tracer that records LangChain execution.""" def __init__(self) -> None: @@ -133,58 +49,6 @@ class FakeTracer(Tracer): """Persist a run.""" self.runs.append(run) - def _add_child_run( - self, - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], - ) -> None: - """Add child run to a chain run or tool run.""" - _add_child_run(parent_run, child_run) - - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return _generate_id() - - def _persist_session(self, session: TracerSessionCreate) -> TracerSession: - """Persist a tracing session.""" - return _persist_session(session) - - def load_session(self, session_name: str) -> TracerSession: - """Load a tracing session.""" - return load_session(session_name) - - def load_default_session(self) -> TracerSession: - """Load a tracing session.""" - return load_default_session() - - -class FakeSharedTracer(SharedTracer): - """Fake shared tracer that records LangChain execution.""" - - runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] - - def _persist_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> None: - """Persist a run.""" - with self._lock: - self.runs.append(run) - - def remove_runs(self) -> None: - """Remove all runs.""" - with self._lock: - self.runs = [] - - def _add_child_run( - self, - parent_run: Union[ChainRun, ToolRun], - child_run: Union[LLMRun, ChainRun, ToolRun], - ) -> None: - """Add child run to a chain run or tool run.""" - _add_child_run(parent_run, child_run) - - def _generate_id(self) -> Optional[Union[int, str]]: - """Generate an id for a run.""" - return _generate_id() - def _persist_session(self, session: TracerSessionCreate) -> TracerSession: """Persist a tracing session.""" return _persist_session(session) @@ -201,12 +65,15 @@ class FakeSharedTracer(SharedTracer): @freeze_time("2023-01-01") def test_tracer_llm_run() -> None: """Test tracer on an LLM run.""" + uuid = uuid4() compare_run = LLMRun( - id=None, + uuid=str(uuid), + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, prompts=[], response=LLMResult(generations=[[]]), @@ -216,20 +83,11 @@ def test_tracer_llm_run() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) + tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] -@freeze_time("2023-01-01") -def test_tracer_llm_run_errors_no_session() -> None: - """Test tracer on an LLM run without a session.""" - tracer = FakeTracer() - - with pytest.raises(TracerException): - tracer.on_llm_start(serialized={}, prompts=[]) - - @freeze_time("2023-01-01") def test_tracer_llm_run_errors_no_start() -> None: """Test tracer on an LLM run without a start.""" @@ -237,18 +95,21 @@ def test_tracer_llm_run_errors_no_start() -> None: 
tracer.new_session() with pytest.raises(TracerException): - tracer.on_llm_end(response=LLMResult(generations=[[]])) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4()) @freeze_time("2023-01-01") def test_tracer_multiple_llm_runs() -> None: """Test the tracer with multiple runs.""" + uuid = uuid4() compare_run = LLMRun( - id=None, + uuid=str(uuid), + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, prompts=[], response=LLMResult(generations=[[]]), @@ -260,8 +121,8 @@ def test_tracer_multiple_llm_runs() -> None: tracer.new_session() num_runs = 10 for _ in range(num_runs): - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) + tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] * num_runs @@ -269,12 +130,15 @@ def test_tracer_multiple_llm_runs() -> None: @freeze_time("2023-01-01") def test_tracer_chain_run() -> None: """Test tracer on a Chain run.""" + uuid = uuid4() compare_run = ChainRun( - id=None, + uuid=str(uuid), + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, inputs={}, outputs={}, @@ -284,20 +148,23 @@ def test_tracer_chain_run() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_chain_end(outputs={}) + tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid) + tracer.on_chain_end(outputs={}, run_id=uuid) assert tracer.runs == [compare_run] @freeze_time("2023-01-01") def test_tracer_tool_run() -> None: """Test tracer on a Tool run.""" + uuid = uuid4() compare_run = ToolRun( - id=None, + uuid=str(uuid), + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, tool_input="test", output="test", @@ -308,8 +175,8 @@ def test_tracer_tool_run() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_tool_start(serialized={}, input_str="test") - tracer.on_tool_end("test") + tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid) + tracer.on_tool_end("test", run_id=uuid) assert tracer.runs == [compare_run] @@ -318,21 +185,109 @@ def test_tracer_nested_run() -> None: """Test tracer on a nested run.""" tracer = FakeTracer() tracer.new_session() - _perform_nested_run(tracer) - assert tracer.runs == [_get_compare_run()] + + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + for _ in range(10): + tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) + tracer.on_tool_start( + serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid + ) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_tool_end("test", run_id=tool_uuid) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_chain_end(outputs={}, run_id=chain_uuid) + + compare_run = ChainRun( + uuid=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=4, + 
serialized={}, + inputs={}, + outputs={}, + session_id=TEST_SESSION_ID, + child_chain_runs=[], + child_tool_runs=[ + ToolRun( + uuid=str(tool_uuid), + parent_uuid=str(chain_uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=2, + child_execution_order=3, + serialized={}, + tool_input="test", + output="test", + action="{}", + session_id=TEST_SESSION_ID, + error=None, + child_chain_runs=[], + child_tool_runs=[], + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid1), + parent_uuid=str(tool_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=3, + child_execution_order=3, + serialized={}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ) + ], + ), + ], + child_llm_runs=[ + LLMRun( + uuid=str(llm_uuid2), + parent_uuid=str(chain_uuid), + error=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=4, + child_execution_order=4, + serialized={}, + prompts=[], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + ), + ], + ) + assert tracer.runs == [compare_run] * 10 @freeze_time("2023-01-01") def test_tracer_llm_run_on_error() -> None: """Test tracer on an LLM run with an error.""" exception = Exception("test") + uuid = uuid4() compare_run = LLMRun( - id=None, + uuid=str(uuid), + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, prompts=[], response=None, @@ -342,8 +297,8 @@ def test_tracer_llm_run_on_error() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_error(exception) + tracer.on_llm_start(serialized={}, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -351,13 +306,16 @@ def test_tracer_llm_run_on_error() -> None: def test_tracer_chain_run_on_error() -> None: """Test tracer on a Chain run with an error.""" exception = Exception("test") + uuid = uuid4() compare_run = ChainRun( - id=None, + uuid=str(uuid), + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, inputs={}, outputs=None, @@ -367,8 +325,8 @@ def test_tracer_chain_run_on_error() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_chain_error(exception) + tracer.on_chain_start(serialized={}, inputs={}, run_id=uuid) + tracer.on_chain_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -376,13 +334,16 @@ def test_tracer_chain_run_on_error() -> None: def test_tracer_tool_run_on_error() -> None: """Test tracer on a Tool run with an error.""" exception = Exception("test") + uuid = uuid4() compare_run = ToolRun( - id=None, + uuid=str(uuid), + parent_uuid=None, start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=1, serialized={}, tool_input="test", output=None, @@ -393,8 +354,8 @@ def test_tracer_tool_run_on_error() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_tool_start(serialized={}, input_str="test") - tracer.on_tool_error(exception) + tracer.on_tool_start(serialized={}, input_str="test", run_id=uuid) + tracer.on_tool_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -405,37 +366,53 @@ def test_tracer_nested_runs_on_error() -> None: tracer = 
FakeTracer() tracer.new_session() + chain_uuid = uuid4() + tool_uuid = uuid4() + llm_uuid1 = uuid4() + llm_uuid2 = uuid4() + llm_uuid3 = uuid4() for _ in range(3): - tracer.on_chain_start(serialized={}, inputs={}) - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_end(response=LLMResult(generations=[[]])) - tracer.on_tool_start(serialized={}, input_str="test") - tracer.on_llm_start(serialized={}, prompts=[]) - tracer.on_llm_error(exception) - tracer.on_tool_error(exception) - tracer.on_chain_error(exception) + tracer.on_chain_start(serialized={}, inputs={}, run_id=chain_uuid) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid + ) + tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2) + tracer.on_tool_start( + serialized={}, input_str="test", run_id=tool_uuid, parent_run_id=chain_uuid + ) + tracer.on_llm_start( + serialized={}, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid + ) + tracer.on_llm_error(exception, run_id=llm_uuid3) + tracer.on_tool_error(exception, run_id=tool_uuid) + tracer.on_chain_error(exception, run_id=chain_uuid) compare_run = ChainRun( - id=None, + uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=1, + child_execution_order=5, serialized={}, session_id=TEST_SESSION_ID, error=repr(exception), inputs={}, outputs=None, - child_runs=[ + child_llm_runs=[ LLMRun( - id=None, + uuid=str(llm_uuid1), + parent_uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=2, + child_execution_order=2, serialized={}, session_id=TEST_SESSION_ID, error=None, @@ -443,36 +420,45 @@ def test_tracer_nested_runs_on_error() -> None: response=LLMResult(generations=[[]], llm_output=None), ), LLMRun( - id=None, + uuid=str(llm_uuid2), + parent_uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=3, + child_execution_order=3, serialized={}, session_id=TEST_SESSION_ID, error=None, prompts=[], response=LLMResult(generations=[[]], llm_output=None), ), + ], + child_chain_runs=[], + child_tool_runs=[ ToolRun( - id=None, + uuid=str(tool_uuid), + parent_uuid=str(chain_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=4, + child_execution_order=5, serialized={}, session_id=TEST_SESSION_ID, error=repr(exception), tool_input="test", output=None, action="{}", - child_runs=[ + child_llm_runs=[ LLMRun( - id=None, + uuid=str(llm_uuid3), + parent_uuid=str(tool_uuid), start_time=datetime.utcnow(), end_time=datetime.utcnow(), extra={}, execution_order=5, + child_execution_order=5, serialized={}, session_id=TEST_SESSION_ID, error=repr(exception), @@ -480,43 +466,10 @@ def test_tracer_nested_runs_on_error() -> None: response=None, ) ], - child_llm_runs=[], child_chain_runs=[], child_tool_runs=[], ), ], - child_llm_runs=[], - child_chain_runs=[], - child_tool_runs=[], ) assert tracer.runs == [compare_run] * 3 - - -@freeze_time("2023-01-01") -def test_shared_tracer_nested_run() -> None: - """Test shared tracer on a nested run.""" - tracer = FakeSharedTracer() - tracer.new_session() - tracer.remove_runs() - _perform_nested_run(tracer) - assert tracer.runs == 
[_get_compare_run()] - - -@freeze_time("2023-01-01") -def test_shared_tracer_nested_run_multithreaded() -> None: - """Test shared tracer on a nested run.""" - tracer = FakeSharedTracer() - tracer.remove_runs() - tracer.new_session() - threads = [] - num_threads = 10 - for _ in range(num_threads): - thread = threading.Thread(target=_perform_nested_run, args=(tracer,)) - thread.start() - threads.append(thread) - - for thread in threads: - thread.join() - - assert tracer.runs == [_get_compare_run()] * num_threads diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py index 0b0aebf7..1e5022b8 100644 --- a/tests/unit_tests/chains/test_base.py +++ b/tests/unit_tests/chains/test_base.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional import pytest -from langchain.callbacks.base import CallbackManager +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.schema import BaseMemory from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -25,11 +25,9 @@ class FakeMemory(BaseMemory): def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: """Pass.""" - pass def clear(self) -> None: """Pass.""" - pass class FakeChain(Chain): @@ -49,7 +47,11 @@ class FakeChain(Chain): """Output key of bar.""" return self.the_output_keys - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: if self.be_correct: return {"bar": "baz"} else: @@ -143,25 +145,10 @@ def test_run_with_callback() -> None: """Test run method works when callback manager is passed.""" handler = FakeCallbackHandler() chain = FakeChain( - callback_manager=CallbackManager(handlers=[handler]), verbose=True + callbacks=[handler], ) output = chain.run("bar") assert output == "baz" assert handler.starts == 1 assert handler.ends == 1 assert handler.errors == 0 - - -def test_run_with_callback_not_verbose() -> None: - """Test run method works when callback manager is passed and not verbose.""" - import langchain - - langchain.verbose = False - - handler = FakeCallbackHandler() - chain = FakeChain(callback_manager=CallbackManager(handlers=[handler])) - output = chain.run("bar") - assert output == "baz" - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 diff --git a/tests/unit_tests/chains/test_hyde.py b/tests/unit_tests/chains/test_hyde.py index cc3e6ae4..dd2ade83 100644 --- a/tests/unit_tests/chains/test_hyde.py +++ b/tests/unit_tests/chains/test_hyde.py @@ -3,6 +3,10 @@ from typing import List, Optional import numpy as np +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.embeddings.base import Embeddings @@ -28,12 +32,18 @@ class FakeLLM(BaseLLM): n: int = 1 def _generate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> LLMResult: return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) async def _agenerate( - self, prompts: List[str], stop: Optional[List[str]] = None + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: 
Optional[AsyncCallbackManagerForLLMRun] = None, ) -> LLMResult: return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) diff --git a/tests/unit_tests/chains/test_llm_bash.py b/tests/unit_tests/chains/test_llm_bash.py index 3e20e356..e6ee11d0 100644 --- a/tests/unit_tests/chains/test_llm_bash.py +++ b/tests/unit_tests/chains/test_llm_bash.py @@ -3,8 +3,8 @@ import sys import pytest -from langchain.chains.llm_bash.base import BashOutputParser, LLMBashChain -from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE +from langchain.chains.llm_bash.base import LLMBashChain +from langchain.chains.llm_bash.prompt import _PROMPT_TEMPLATE, BashOutputParser from langchain.schema import OutputParserException from tests.unit_tests.llms.fake_llm import FakeLLM @@ -43,7 +43,7 @@ def test_simple_question() -> None: prompt = _PROMPT_TEMPLATE.format(question=question) queries = {prompt: "```bash\nexpr 1 + 1\n```"} fake_llm = FakeLLM(queries=queries) - fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a") + fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a") output = fake_llm_bash_chain.run(question) assert output == "2\n" @@ -71,7 +71,7 @@ echo 'hello world' """ } fake_llm = FakeLLM(queries=queries) - fake_llm_bash_chain = LLMBashChain(llm=fake_llm, input_key="q", output_key="a") + fake_llm_bash_chain = LLMBashChain.from_llm(fake_llm, input_key="q", output_key="a") with pytest.raises(OutputParserException): fake_llm_bash_chain.run(question) diff --git a/tests/unit_tests/chains/test_llm_checker.py b/tests/unit_tests/chains/test_llm_checker.py index 0c9b9343..cc2ceb99 100644 --- a/tests/unit_tests/chains/test_llm_checker.py +++ b/tests/unit_tests/chains/test_llm_checker.py @@ -33,7 +33,7 @@ def fake_llm_checker_chain() -> LLMCheckerChain: ): "I still don't know.", } fake_llm = FakeLLM(queries=queries) - return LLMCheckerChain(llm=fake_llm, input_key="q", output_key="a") + return LLMCheckerChain.from_llm(fake_llm, input_key="q", output_key="a") def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None: diff --git a/tests/unit_tests/chains/test_llm_math.py b/tests/unit_tests/chains/test_llm_math.py index c412436c..4e3887ab 100644 --- a/tests/unit_tests/chains/test_llm_math.py +++ b/tests/unit_tests/chains/test_llm_math.py @@ -17,7 +17,7 @@ def fake_llm_math_chain() -> LLMMathChain: _PROMPT_TEMPLATE.format(question="foo"): "foo", } fake_llm = FakeLLM(queries=queries) - return LLMMathChain(llm=fake_llm, input_key="q", output_key="a") + return LLMMathChain.from_llm(fake_llm, input_key="q", output_key="a") def test_simple_question(fake_llm_math_chain: LLMMathChain) -> None: diff --git a/tests/unit_tests/chains/test_llm_summarization_checker.py b/tests/unit_tests/chains/test_llm_summarization_checker.py index 81e4a8fa..aa82cead 100644 --- a/tests/unit_tests/chains/test_llm_summarization_checker.py +++ b/tests/unit_tests/chains/test_llm_summarization_checker.py @@ -32,7 +32,9 @@ def fake_llm_summarization_checker_chain() -> LLMSummarizationCheckerChain: ): "True", } fake_llm = FakeLLM(queries=queries) - return LLMSummarizationCheckerChain(llm=fake_llm, input_key="q", output_key="a") + return LLMSummarizationCheckerChain.from_llm( + fake_llm, input_key="q", output_key="a" + ) def test_simple_text( diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index fd30901a..77c29808 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -2,6 
+2,7 @@ from typing import Any, List, Mapping, Optional +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chains.natbot.base import NatBotChain from langchain.llms.base import LLM @@ -9,7 +10,12 @@ from langchain.llms.base import LLM class FakeLLM(LLM): """Fake LLM wrapper for testing purposes.""" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: """Return `foo` if longer than 10000 words, else `bar`.""" if len(prompt) > 10000: return "foo" @@ -28,7 +34,7 @@ class FakeLLM(LLM): def test_proper_inputs() -> None: """Test that natbot shortens inputs correctly.""" - nat_bot_chain = NatBotChain(llm=FakeLLM(), objective="testing") + nat_bot_chain = NatBotChain.from_llm(FakeLLM(), objective="testing") url = "foo" * 10000 browser_content = "foo" * 10000 output = nat_bot_chain.execute(url, browser_content) @@ -37,8 +43,8 @@ def test_proper_inputs() -> None: def test_variable_key_naming() -> None: """Test that natbot handles variable key naming correctly.""" - nat_bot_chain = NatBotChain( - llm=FakeLLM(), + nat_bot_chain = NatBotChain.from_llm( + FakeLLM(), objective="testing", input_url_key="u", input_browser_content_key="b", diff --git a/tests/unit_tests/chains/test_sequential.py b/tests/unit_tests/chains/test_sequential.py index 2ef0e7d4..19e7df10 100644 --- a/tests/unit_tests/chains/test_sequential.py +++ b/tests/unit_tests/chains/test_sequential.py @@ -1,8 +1,9 @@ """Test pipeline functionality.""" -from typing import Dict, List +from typing import Dict, List, Optional import pytest +from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.memory.simple import SimpleMemory @@ -24,7 +25,11 @@ class FakeChain(Chain): """Input keys this chain returns.""" return self.output_variables - def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + def _call( + self, + inputs: Dict[str, str], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: outputs = {} for var in self.output_variables: variables = [inputs[k] for k in self.input_variables] diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index cc12a7ca..8815cc0b 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -3,6 +3,7 @@ from typing import Any, List, Mapping, Optional, cast from pydantic import validator +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM @@ -28,7 +29,12 @@ class FakeLLM(LLM): """Return type of llm.""" return "fake" - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: if self.sequential_responses: return self._get_next_response_in_sequence diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py index d9d52630..ce0cf77f 100644 --- a/tests/unit_tests/llms/test_callbacks.py +++ b/tests/unit_tests/llms/test_callbacks.py @@ -1,5 +1,4 @@ """Test LLM callbacks.""" -from langchain.callbacks.base import CallbackManager from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.llms.fake_llm import FakeLLM @@ 
-7,24 +6,9 @@ from tests.unit_tests.llms.fake_llm import FakeLLM def test_llm_with_callbacks() -> None: """Test LLM callbacks.""" handler = FakeCallbackHandler() - llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler]), verbose=True) + llm = FakeLLM(callbacks=[handler], verbose=True) output = llm("foo") assert output == "foo" assert handler.starts == 1 assert handler.ends == 1 assert handler.errors == 0 - - - def test_llm_with_callbacks_not_verbose() -> None: - """Test LLM callbacks but not verbose.""" - import langchain - - langchain.verbose = False - - handler = FakeCallbackHandler() - llm = FakeLLM(callback_manager=CallbackManager(handlers=[handler])) - output = llm("foo") - assert output == "foo" - assert handler.starts == 0 - assert handler.ends == 0 - assert handler.errors == 0 diff --git a/tests/unit_tests/tools/test_signatures.py b/tests/unit_tests/tools/test_signatures.py new file mode 100644 index 00000000..b1634dfc --- /dev/null +++ b/tests/unit_tests/tools/test_signatures.py @@ -0,0 +1,45 @@ +"""Test base tool child implementations.""" + + +import inspect +import re +from typing import List, Type + +import pytest + +from langchain.tools.base import BaseTool +from langchain.tools.playwright.base import BaseBrowserTool + + +def get_non_abstract_subclasses(cls: Type[BaseTool]) -> List[Type[BaseTool]]: + to_skip = {BaseBrowserTool} # Abstract but not recognized + subclasses = [] + for subclass in cls.__subclasses__(): + if ( + not getattr(subclass, "__abstract__", None) + and not subclass.__name__.startswith("_") + and subclass not in to_skip + ): + subclasses.append(subclass) + subclasses.extend(get_non_abstract_subclasses(subclass)) + return subclasses + + +@pytest.mark.parametrize("cls", get_non_abstract_subclasses(BaseTool)) # type: ignore +def test_all_subclasses_accept_run_manager(cls: Type[BaseTool]) -> None: + """Test that tools defined in this repo accept a run manager argument.""" + # This wouldn't be necessary if the BaseTool had a strict API. + if cls._run is not BaseTool._run: + run_func = cls._run + params = inspect.signature(run_func).parameters + assert "run_manager" in params + pattern = re.compile(r"(?<!Async)CallbackManagerForToolRun") + assert bool(re.search(pattern, str(params["run_manager"].annotation))) + assert params["run_manager"].default is None + + if cls._arun is not BaseTool._arun: + run_func = cls._arun + params = inspect.signature(run_func).parameters + assert "run_manager" in params + assert "AsyncCallbackManagerForToolRun" in str(params["run_manager"].annotation) + assert params["run_manager"].default is None
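The test hunks above all exercise the same new surface: handlers are attached directly through a `callbacks=` list rather than a `CallbackManager`, and `_call`/`_generate`/`_run` implementations take an optional `run_manager`. As a minimal illustrative sketch only (the `EchoLLM` class below is hypothetical and not part of this diff; the `LLM` base class, `CallbackManagerForLLMRun`, and `StdOutCallbackHandler` are the ones the tests already import), a custom LLM written against that interface might look like:

    from typing import Any, List, Mapping, Optional

    from langchain.callbacks import StdOutCallbackHandler
    from langchain.callbacks.manager import CallbackManagerForLLMRun
    from langchain.llms.base import LLM


    class EchoLLM(LLM):
        """Toy LLM that returns its prompt unchanged (illustration only)."""

        @property
        def _llm_type(self) -> str:
            return "echo"

        @property
        def _identifying_params(self) -> Mapping[str, Any]:
            return {}

        def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            # Per-run callback hook in the new API; None when no handlers fire.
            run_manager: Optional[CallbackManagerForLLMRun] = None,
        ) -> str:
            return prompt


    # Handlers go in as a plain list, mirroring FakeLLM(callbacks=[handler]) above.
    llm = EchoLLM(callbacks=[StdOutCallbackHandler()])
    assert llm("hello") == "hello"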