You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/docs/modules/chains/generic/custom_chain.ipynb

200 lines
6.5 KiB
Plaintext

{
"cells": [
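{
"cell_type": "markdown",
"id": "593f7553-7038-498e-96d4-8255e5ce34f0",
"metadata": {},
"source": [
"# Creating a custom Chain\n",
"\n",
"To implement your own custom chain, subclass `Chain` and implement the members shown in the example below: the `input_keys` and `output_keys` properties, the `_call` method, and its async counterpart `_acall` (the example also defines a `_chain_type` property)."
]
},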
{
"cell_type": "code",
"execution_count": 11,
"id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
"metadata": {
"tags": [],
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"from typing import Any, Dict, List, Optional\n",
"\n",
"from pydantic import Extra\n",
"\n",
"from langchain.base_language import BaseLanguageModel\n",
"from langchain.callbacks.manager import (\n",
" AsyncCallbackManagerForChainRun,\n",
" CallbackManagerForChainRun,\n",
")\n",
"from langchain.chains.base import Chain\n",
"from langchain.prompts.base import BasePromptTemplate\n",
"\n",
"\n",
"class MyCustomChain(Chain):\n",
" \"\"\"\n",
" An example of a custom chain.\n",
" \"\"\"\n",
"\n",
" prompt: BasePromptTemplate\n",
" \"\"\"Prompt object to use.\"\"\"\n",
" llm: BaseLanguageModel\n",
" output_key: str = \"text\" #: :meta private:\n",
"\n",
" class Config:\n",
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
"\n",
" extra = Extra.forbid\n",
" arbitrary_types_allowed = True\n",
"\n",
" @property\n",
" def input_keys(self) -> List[str]:\n",
" \"\"\"Will be whatever keys the prompt expects.\n",
"\n",
" :meta private:\n",
" \"\"\"\n",
" return self.prompt.input_variables\n",
"\n",
" @property\n",
" def output_keys(self) -> List[str]:\n",
" \"\"\"Will always return text key.\n",
"\n",
" :meta private:\n",
" \"\"\"\n",
" return [self.output_key]\n",
"\n",
" def _call(\n",
" self,\n",
" inputs: Dict[str, Any],\n",
" run_manager: Optional[CallbackManagerForChainRun] = None,\n",
" ) -> Dict[str, str]:\n",
" # Your custom chain logic goes here\n",
" # This is just an example that mimics LLMChain\n",
" prompt_value = self.prompt.format_prompt(**inputs)\n",
" \n",
" # Whenever you call a language model, or another chain, you should pass\n",
" # a callback manager to it. This allows the inner run to be tracked by\n",
" # any callbacks that are registered on the outer run.\n",
" # When a `run_manager` is provided, obtain a child callback manager by\n",
" # calling `run_manager.get_child()` as shown below; otherwise pass None.\n",
" response = self.llm.generate_prompt(\n",
" [prompt_value],\n",
" callbacks=run_manager.get_child() if run_manager else None\n",
" )\n",
"\n",
" # If you want to log something about this run, you can do so by calling\n",
" # methods on the `run_manager`, as shown below. This will trigger any\n",
" # callbacks that are registered for that event.\n",
" if run_manager:\n",
" run_manager.on_text(\"Log something about this run\")\n",
" \n",
" return {self.output_key: response.generations[0][0].text}\n",
"\n",
" async def _acall(\n",
" self,\n",
" inputs: Dict[str, Any],\n",
" run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
" ) -> Dict[str, str]:\n",
" # Your custom chain logic goes here\n",
" # This is just an example that mimics LLMChain\n",
" prompt_value = self.prompt.format_prompt(**inputs)\n",
" \n",
" # Whenever you call a language model, or another chain, you should pass\n",
" # a callback manager to it. This allows the inner run to be tracked by\n",
" # any callbacks that are registered on the outer run.\n",
" # When a `run_manager` is provided, obtain a child callback manager by\n",
" # calling `run_manager.get_child()` as shown below; otherwise pass None.\n",
" response = await self.llm.agenerate_prompt(\n",
" [prompt_value],\n",
" callbacks=run_manager.get_child() if run_manager else None\n",
" )\n",
"\n",
" # If you want to log something about this run, you can do so by calling\n",
" # methods on the `run_manager`, as shown below. This will trigger any\n",
" # callbacks that are registered for that event.\n",
" if run_manager:\n",
" await run_manager.on_text(\"Log something about this run\")\n",
" \n",
" return {self.output_key: response.generations[0][0].text}\n",
"\n",
" @property\n",
" def _chain_type(self) -> str:\n",
" return \"my_custom_chain\"\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "18361f89",
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
"Log something about this run\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.callbacks.stdout import StdOutCallbackHandler\n",
"from langchain.chat_models.openai import ChatOpenAI\n",
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"\n",
"chain = MyCustomChain(\n",
" prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n",
" llm=ChatOpenAI()\n",
")\n",
"\n",
"chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}