mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
8d3b059332
Basically copy what's in the ts docs: https://js.langchain.com/docs/production/callbacks Discovered a bug wrt not awaiting callbacks in `LLMMathChain` so fixed that
389 lines
15 KiB
Plaintext
389 lines
15 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "23234b50-e6c6-4c87-9f97-259c15f36894",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"source": [
|
||
"# Callbacks"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "29dd6333-307c-43df-b848-65001c01733b",
|
||
"metadata": {},
|
||
"source": [
|
||
"LangChain provides a callback system that allows you to hook into the various stages of your LLM application. This is useful for logging, [monitoring](https://python.langchain.com/en/latest/tracing.html), [streaming](https://python.langchain.com/en/latest/modules/models/llms/examples/streaming_llm.html), and other tasks.\n",
|
||
"\n",
|
||
"You can subscribe to these events by using the `callback_manager` argument available throughout the API. A `CallbackManager` is an object that manages a list of `CallbackHandlers`. The `CallbackManager` will call the appropriate method on each handler when the event is triggered."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "fdb72e8d-a02a-474d-96bf-f5759432afc8",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"source": [
|
||
"```python\n",
|
||
"class CallbackManager(BaseCallbackHandler):\n",
|
||
" \"\"\"Base callback manager that can be used to handle callbacks from LangChain.\"\"\"\n",
|
||
"\n",
|
||
" def add_handler(self, callback: BaseCallbackHandler) -> None:\n",
|
||
" \"\"\"Add a handler to the callback manager.\"\"\"\n",
|
||
"\n",
|
||
" def remove_handler(self, handler: BaseCallbackHandler) -> None:\n",
|
||
" \"\"\"Remove a handler from the callback manager.\"\"\"\n",
|
||
"\n",
|
||
" def set_handler(self, handler: BaseCallbackHandler) -> None:\n",
|
||
" \"\"\"Set handler as the only handler on the callback manager.\"\"\"\n",
|
||
" self.set_handlers([handler])\n",
|
||
"\n",
|
||
" def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:\n",
|
||
" \"\"\"Set handlers as the only handlers on the callback manager.\"\"\"\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "2b6d7dba-cd20-472a-ae05-f68675cc9ea4",
|
||
"metadata": {},
|
||
"source": [
|
||
"`CallbackHandlers` are objects that implement the `BaseCallbackHandler` interface, which has a method for each event that can be subscribed to. The `CallbackManager` will call the appropriate method on each handler when the event is triggered."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e4592215-6604-47e2-89ff-5db3af6d1e40",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"source": [
|
||
"```python\n",
|
||
"class BaseCallbackHandler(ABC):\n",
|
||
" \"\"\"Base callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_llm_start(\n",
|
||
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
|
||
" ) -> Any:\n",
|
||
" \"\"\"Run when LLM starts running.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:\n",
|
||
" \"\"\"Run on new LLM token. Only available when streaming is enabled.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:\n",
|
||
" \"\"\"Run when LLM ends running.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_llm_error(\n",
|
||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||
" ) -> Any:\n",
|
||
" \"\"\"Run when LLM errors.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_chain_start(\n",
|
||
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
|
||
" ) -> Any:\n",
|
||
" \"\"\"Run when chain starts running.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:\n",
|
||
" \"\"\"Run when chain ends running.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_chain_error(\n",
|
||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||
" ) -> Any:\n",
|
||
" \"\"\"Run when chain errors.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_tool_start(\n",
|
||
" self, serialized: Dict[str, Any], input_str: str, **kwargs: Any\n",
|
||
" ) -> Any:\n",
|
||
" \"\"\"Run when tool starts running.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_tool_end(self, output: str, **kwargs: Any) -> Any:\n",
|
||
" \"\"\"Run when tool ends running.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_tool_error(\n",
|
||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||
" ) -> Any:\n",
|
||
" \"\"\"Run when tool errors.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_text(self, text: str, **kwargs: Any) -> Any:\n",
|
||
" \"\"\"Run on arbitrary text.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:\n",
|
||
" \"\"\"Run on agent action.\"\"\"\n",
|
||
"\n",
|
||
" @abstractmethod\n",
|
||
" def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:\n",
|
||
" \"\"\"Run on agent end.\"\"\"\n",
|
||
"```"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d3bf3304-43fb-47ad-ae50-0637a17018a2",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Creating and Using a Custom `CallbackHandler`\n",
|
||
"\n",
|
||
"By default, a shared `CallbackManager` with the `StdOutCallbackHandler` will be used by models, chains, agents, and tools. However, you can pass in your own `CallbackManager` with a custom `CallbackHandler`:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "80532dfc-d687-4147-a0c9-1f90cc3e868c",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"AgentAction(tool='Search', tool_input=\"US Open men's final 2019 winner\", log=' I need to find out who won the US Open men\\'s final in 2019 and then calculate his age raised to the 0.334 power.\\nAction: Search\\nAction Input: \"US Open men\\'s final 2019 winner\"')\n",
|
||
"Rafael Nadal defeated Daniil Medvedev in the final, 7–5, 6–3, 5–7, 4–6, 6–4 to win the men's singles tennis title at the 2019 US Open. It was his fourth US ...\n",
|
||
"AgentAction(tool='Search', tool_input='Rafael Nadal age', log=' I need to find out the age of the winner\\nAction: Search\\nAction Input: \"Rafael Nadal age\"')\n",
|
||
"36 years\n",
|
||
"AgentAction(tool='Calculator', tool_input='36^0.334', log=' I now need to calculate his age raised to the 0.334 power\\nAction: Calculator\\nAction Input: 36^0.334')\n",
|
||
"Answer: 3.3098250249682484\n",
|
||
"\n",
|
||
" I now know the final answer\n",
|
||
"Final Answer: Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"\"Rafael Nadal, aged 36, won the US Open men's final in 2019 and his age raised to the 0.334 power is 3.3098250249682484.\""
|
||
]
|
||
},
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"from typing import Any, Dict, List, Optional, Union\n",
|
||
"\n",
|
||
"from langchain.agents import initialize_agent, load_tools\n",
|
||
"from langchain.agents import AgentType\n",
|
||
"from langchain.callbacks.base import CallbackManager, BaseCallbackHandler\n",
|
||
"from langchain.llms import OpenAI\n",
|
||
"from langchain.schema import AgentAction, AgentFinish, LLMResult\n",
|
||
"\n",
|
||
"class MyCustomCallbackHandler(BaseCallbackHandler):\n",
|
||
" \"\"\"Custom CallbackHandler.\"\"\"\n",
|
||
"\n",
|
||
" def on_llm_start(\n",
|
||
" self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Do nothing.\"\"\"\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n",
|
||
" \"\"\"Do nothing.\"\"\"\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def on_llm_new_token(self, token: str, **kwargs: Any) -> None:\n",
|
||
" \"\"\"Do nothing.\"\"\"\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def on_llm_error(\n",
|
||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Do nothing.\"\"\"\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def on_chain_start(\n",
|
||
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Print out that we are entering a chain.\"\"\"\n",
|
||
" class_name = serialized[\"name\"]\n",
|
||
" print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n",
|
||
"\n",
|
||
" def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n",
|
||
" \"\"\"Print out that we finished a chain.\"\"\"\n",
|
||
" print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n",
|
||
"\n",
|
||
" def on_chain_error(\n",
|
||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Do nothing.\"\"\"\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def on_tool_start(\n",
|
||
" self,\n",
|
||
" serialized: Dict[str, Any],\n",
|
||
" input_str: str,\n",
|
||
" **kwargs: Any,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Do nothing.\"\"\"\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def on_agent_action(\n",
|
||
" self, action: AgentAction, color: Optional[str] = None, **kwargs: Any\n",
|
||
" ) -> Any:\n",
|
||
" \"\"\"Run on agent action.\"\"\"\n",
|
||
" print(action)\n",
|
||
"\n",
|
||
" def on_tool_end(\n",
|
||
" self,\n",
|
||
" output: str,\n",
|
||
" color: Optional[str] = None,\n",
|
||
" observation_prefix: Optional[str] = None,\n",
|
||
" llm_prefix: Optional[str] = None,\n",
|
||
" **kwargs: Any,\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Print out the tool output (the observation).\"\"\"\n",
|
||
" print(output)\n",
|
||
"\n",
|
||
" def on_tool_error(\n",
|
||
" self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Do nothing.\"\"\"\n",
|
||
" pass\n",
|
||
"\n",
|
||
" def on_text(\n",
|
||
" self,\n",
|
||
" text: str,\n",
|
||
" color: Optional[str] = None,\n",
|
||
" end: str = \"\",\n",
|
||
" **kwargs: Optional[str],\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Run on arbitrary text.\"\"\"\n",
|
||
" print(text)\n",
|
||
"\n",
|
||
" def on_agent_finish(\n",
|
||
" self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Run on agent end.\"\"\"\n",
|
||
" print(finish.log)\n",
|
||
"manager = CallbackManager([MyCustomCallbackHandler()])\n",
|
||
"llm = OpenAI(temperature=0, callback_manager=manager, verbose=True)\n",
|
||
"tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, callback_manager=manager)\n",
|
||
"agent = initialize_agent(\n",
|
||
" tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager\n",
|
||
")\n",
|
||
"agent.run(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bc9785fa-4f71-4797-91a3-4fe7e57d0429",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"source": [
|
||
"## Async Support\n",
|
||
"\n",
|
||
"If you are planning to use the async API, it is recommended to use `AsyncCallbackHandler` and `AsyncCallbackManager` to avoid blocking the event loop."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "c702e0c9-a961-4897-90c1-cdd13b6f16b2",
|
||
"metadata": {
|
||
"tags": []
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"zzzz....\n",
|
||
"\n",
|
||
"\n",
|
||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||
"zzzz....\n",
|
||
"\n",
|
||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import asyncio\n",
|
||
"from aiohttp import ClientSession\n",
|
||
"\n",
|
||
"from langchain.callbacks.base import AsyncCallbackHandler, AsyncCallbackManager\n",
|
||
"\n",
|
||
"class MyCustomAsyncCallbackHandler(AsyncCallbackHandler):\n",
|
||
" \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
|
||
"\n",
|
||
" async def on_chain_start(\n",
|
||
" self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any\n",
|
||
" ) -> None:\n",
|
||
" \"\"\"Run when chain starts running.\"\"\"\n",
|
||
" print(\"zzzz....\")\n",
|
||
" await asyncio.sleep(0.5)\n",
|
||
" class_name = serialized[\"name\"]\n",
|
||
" print(f\"\\n\\n\\033[1m> Entering new {class_name} chain...\\033[0m\")\n",
|
||
"\n",
|
||
" async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:\n",
|
||
" \"\"\"Run when chain ends running.\"\"\"\n",
|
||
" print(\"zzzz....\")\n",
|
||
" await asyncio.sleep(0.5)\n",
|
||
" print(\"\\n\\033[1m> Finished chain.\\033[0m\")\n",
|
||
"\n",
|
||
"manager = AsyncCallbackManager([MyCustomAsyncCallbackHandler()])\n",
|
||
"\n",
|
||
"# To make async requests in Tools more efficient, you can pass in your own aiohttp.ClientSession, \n",
|
||
"# but you must manually close the client session at the end of your program/event loop\n",
|
||
"aiosession = ClientSession()\n",
|
||
"llm = OpenAI(temperature=0, callback_manager=manager)\n",
|
||
"async_tools = load_tools([\"llm-math\", \"serpapi\"], llm=llm, aiosession=aiosession, callback_manager=manager)\n",
|
||
"async_agent = initialize_agent(async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, callback_manager=manager)\n",
|
||
"await async_agent.arun(\"Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?\")\n",
|
||
"await aiosession.close()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "86be6304-e433-4048-880c-a92a73244407",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.9"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|