{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "593f7553-7038-498e-96d4-8255e5ce34f0",
   "metadata": {},
   "source": [
    "# Custom chain\n",
    "\n",
    "To implement your own custom chain, subclass `Chain` and implement the following methods:\n",
    "\n",
    "- `input_keys` (property): the keys the chain expects to find in its inputs.\n",
    "- `output_keys` (property): the keys the chain returns in its outputs.\n",
    "- `_call`: the chain's logic, executed when the chain is run synchronously.\n",
    "\n",
    "Optionally, you can also implement:\n",
    "\n",
    "- `_acall`: an async version of `_call`, needed to run the chain asynchronously.\n",
    "- `_chain_type`: a string identifying the chain type (used, for example, when saving the chain).\n",
    "\n",
    "The example below mimics `LLMChain`: it formats a prompt with the chain's inputs, sends it to a language model, and returns the generated text. It also shows how to propagate callbacks to inner calls through `run_manager`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
   "metadata": {
    "tags": [],
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "from typing import Any, Dict, List, Optional\n",
    "\n",
    "from pydantic import Extra\n",
    "\n",
    "from langchain.base_language import BaseLanguageModel\n",
    "from langchain.callbacks.manager import (\n",
    "    AsyncCallbackManagerForChainRun,\n",
    "    CallbackManagerForChainRun,\n",
    ")\n",
    "from langchain.chains.base import Chain\n",
    "from langchain.prompts.base import BasePromptTemplate\n",
    "\n",
    "\n",
    "class MyCustomChain(Chain):\n",
    "    \"\"\"\n",
    "    An example of a custom chain.\n",
    "\n",
    "    It formats ``prompt`` with the chain inputs, sends the resulting prompt to\n",
    "    ``llm`` and returns the text of the first generation under ``output_key``.\n",
    "    \"\"\"\n",
    "\n",
    "    prompt: BasePromptTemplate\n",
    "    \"\"\"Prompt object to use.\"\"\"\n",
    "    llm: BaseLanguageModel\n",
    "    \"\"\"Language model that generates the response.\"\"\"\n",
    "    output_key: str = \"text\"  #: :meta private:\n",
    "\n",
    "    class Config:\n",
    "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
    "\n",
    "        extra = Extra.forbid\n",
    "        arbitrary_types_allowed = True\n",
    "\n",
    "    @property\n",
    "    def input_keys(self) -> List[str]:\n",
    "        \"\"\"Will be whatever keys the prompt expects.\n",
    "\n",
    "        :meta private:\n",
    "        \"\"\"\n",
    "        return self.prompt.input_variables\n",
    "\n",
    "    @property\n",
    "    def output_keys(self) -> List[str]:\n",
    "        \"\"\"Will always be a single-element list holding ``output_key``.\n",
    "\n",
    "        :meta private:\n",
    "        \"\"\"\n",
    "        return [self.output_key]\n",
    "\n",
    "    def _prompt_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:\n",
    "        \"\"\"Keep only the inputs that are variables of the prompt.\n",
    "\n",
    "        ``inputs`` can carry keys the prompt does not use (for example extra\n",
    "        keys passed by the caller); they are dropped here so they are never\n",
    "        handed to ``format_prompt``. This mirrors how ``LLMChain`` selects the\n",
    "        values it formats its prompt with.\n",
    "        \"\"\"\n",
    "        return {key: inputs[key] for key in self.prompt.input_variables}\n",
    "\n",
    "    def _call(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        # Your custom chain logic goes here\n",
    "        # This is just an example that mimics LLMChain\n",
    "        prompt_value = self.prompt.format_prompt(**self._prompt_inputs(inputs))\n",
    "\n",
    "        # Whenever you call a language model, or another chain, you should pass\n",
    "        # a callback manager to it. This allows the inner run to be tracked by\n",
    "        # any callbacks that are registered on the outer run.\n",
    "        # You can always obtain a callback manager for this by calling\n",
    "        # `run_manager.get_child()` as shown below.\n",
    "        response = self.llm.generate_prompt(\n",
    "            [prompt_value], callbacks=run_manager.get_child() if run_manager else None\n",
    "        )\n",
    "\n",
    "        # If you want to log something about this run, you can do so by calling\n",
    "        # methods on the `run_manager`, as shown below. This will trigger any\n",
    "        # callbacks that are registered for that event.\n",
    "        if run_manager:\n",
    "            run_manager.on_text(\"Log something about this run\")\n",
    "\n",
    "        return {self.output_key: response.generations[0][0].text}\n",
    "\n",
    "    async def _acall(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        # Your custom chain logic goes here\n",
    "        # This is just an example that mimics LLMChain\n",
    "        prompt_value = self.prompt.format_prompt(**self._prompt_inputs(inputs))\n",
    "\n",
    "        # Whenever you call a language model, or another chain, you should pass\n",
    "        # a callback manager to it. This allows the inner run to be tracked by\n",
    "        # any callbacks that are registered on the outer run.\n",
    "        # You can always obtain a callback manager for this by calling\n",
    "        # `run_manager.get_child()` as shown below.\n",
    "        response = await self.llm.agenerate_prompt(\n",
    "            [prompt_value], callbacks=run_manager.get_child() if run_manager else None\n",
    "        )\n",
    "\n",
    "        # If you want to log something about this run, you can do so by calling\n",
    "        # methods on the `run_manager`, as shown below. This will trigger any\n",
    "        # callbacks that are registered for that event.\n",
    "        if run_manager:\n",
    "            await run_manager.on_text(\"Log something about this run\")\n",
    "\n",
    "        return {self.output_key: response.generations[0][0].text}\n",
    "\n",
    "    @property\n",
    "    def _chain_type(self) -> str:\n",
    "        return \"my_custom_chain\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f5c2d1-6b7e-4f80-9c1d-2e4b6a8f0d37",
   "metadata": {},
   "source": [
    "Now we can use the chain like any other chain. Passing a `StdOutCallbackHandler` prints when the chain starts and finishes, along with the text logged through `run_manager.on_text` inside `_call`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "18361f89",
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
      "Log something about this run\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
    "from langchain.chat_models.openai import ChatOpenAI\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "\n",
    "chain = MyCustomChain(\n",
    "    prompt=PromptTemplate.from_template(\"tell us a joke about {topic}\"),\n",
    "    llm=ChatOpenAI(),\n",
    ")\n",
    "\n",
    "chain.run({\"topic\": \"callbacks\"}, callbacks=[StdOutCallbackHandler()])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}