mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
bdf0c2267f
Fix typo in the document of custom_chain
198 lines
6.5 KiB
Plaintext
198 lines
6.5 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "593f7553-7038-498e-96d4-8255e5ce34f0",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Custom chain\n",
|
|
"\n",
|
|
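"To implement your own custom chain, subclass `Chain` and implement the members shown in the example below: the `input_keys` and `output_keys` properties, the `_call` method (plus `_acall` for async execution), and the `_chain_type` property."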
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
|
|
"metadata": {
|
|
"tags": [],
|
|
"vscode": {
|
|
"languageId": "python"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from __future__ import annotations\n",
|
|
"\n",
|
|
"from typing import Any, Dict, List, Optional\n",
|
|
"\n",
|
|
"from pydantic import Extra\n",
|
|
"\n",
|
|
"from langchain.schema import BaseLanguageModel\n",
|
|
"from langchain.callbacks.manager import (\n",
|
|
" AsyncCallbackManagerForChainRun,\n",
|
|
" CallbackManagerForChainRun,\n",
|
|
")\n",
|
|
"from langchain.chains.base import Chain\n",
|
|
"from langchain.prompts.base import BasePromptTemplate\n",
|
|
"\n",
|
|
"\n",
|
|
"class MyCustomChain(Chain):\n",
|
|
" \"\"\"\n",
|
|
" An example of a custom chain.\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" prompt: BasePromptTemplate\n",
|
|
" \"\"\"Prompt object to use.\"\"\"\n",
|
|
" llm: BaseLanguageModel\n",
|
|
" output_key: str = \"text\" #: :meta private:\n",
|
|
"\n",
|
|
" class Config:\n",
|
|
" \"\"\"Configuration for this pydantic object.\"\"\"\n",
|
|
"\n",
|
|
" extra = Extra.forbid\n",
|
|
" arbitrary_types_allowed = True\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def input_keys(self) -> List[str]:\n",
|
|
" \"\"\"Will be whatever keys the prompt expects.\n",
|
|
"\n",
|
|
" :meta private:\n",
|
|
" \"\"\"\n",
|
|
" return self.prompt.input_variables\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def output_keys(self) -> List[str]:\n",
|
|
" \"\"\"Will always return text key.\n",
|
|
"\n",
|
|
" :meta private:\n",
|
|
" \"\"\"\n",
|
|
" return [self.output_key]\n",
|
|
"\n",
|
|
" def _call(\n",
|
|
" self,\n",
|
|
" inputs: Dict[str, Any],\n",
|
|
" run_manager: Optional[CallbackManagerForChainRun] = None,\n",
|
|
" ) -> Dict[str, str]:\n",
|
|
" # Your custom chain logic goes here\n",
|
|
" # This is just an example that mimics LLMChain\n",
|
|
" prompt_value = self.prompt.format_prompt(**inputs)\n",
|
|
"\n",
|
|
" # Whenever you call a language model, or another chain, you should pass\n",
|
|
" # a callback manager to it. This allows the inner run to be tracked by\n",
|
|
" # any callbacks that are registered on the outer run.\n",
|
|
" # You can always obtain a callback manager for this by calling\n",
|
|
" # `run_manager.get_child()` as shown below.\n",
|
|
" response = self.llm.generate_prompt(\n",
|
|
" [prompt_value], callbacks=run_manager.get_child() if run_manager else None\n",
|
|
" )\n",
|
|
"\n",
|
|
" # If you want to log something about this run, you can do so by calling\n",
|
|
" # methods on the `run_manager`, as shown below. This will trigger any\n",
|
|
" # callbacks that are registered for that event.\n",
|
|
" if run_manager:\n",
|
|
" run_manager.on_text(\"Log something about this run\")\n",
|
|
"\n",
|
|
" return {self.output_key: response.generations[0][0].text}\n",
|
|
"\n",
|
|
" async def _acall(\n",
|
|
" self,\n",
|
|
" inputs: Dict[str, Any],\n",
|
|
" run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
|
|
" ) -> Dict[str, str]:\n",
|
|
" # Your custom chain logic goes here\n",
|
|
" # This is just an example that mimics LLMChain\n",
|
|
" prompt_value = self.prompt.format_prompt(**inputs)\n",
|
|
"\n",
|
|
" # Whenever you call a language model, or another chain, you should pass\n",
|
|
" # a callback manager to it. This allows the inner run to be tracked by\n",
|
|
" # any callbacks that are registered on the outer run.\n",
|
|
" # You can always obtain a callback manager for this by calling\n",
|
|
" # `run_manager.get_child()` as shown below.\n",
|
|
" response = await self.llm.agenerate_prompt(\n",
|
|
" [prompt_value], callbacks=run_manager.get_child() if run_manager else None\n",
|
|
" )\n",
|
|
"\n",
|
|
" # If you want to log something about this run, you can do so by calling\n",
|
|
" # methods on the `run_manager`, as shown below. This will trigger any\n",
|
|
" # callbacks that are registered for that event.\n",
|
|
" if run_manager:\n",
|
|
" await run_manager.on_text(\"Log something about this run\")\n",
|
|
"\n",
|
|
" return {self.output_key: response.generations[0][0].text}\n",
|
|
"\n",
|
|
" @property\n",
|
|
" def _chain_type(self) -> str:\n",
|
|
" return \"my_custom_chain\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "18361f89",
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "python"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
|
|
"Log something about this run\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from langchain.callbacks.stdout import StdOutCallbackHandler\n",
|
|
"from langchain.chat_models.openai import ChatOpenAI\n",
|
|
"from langchain.prompts.prompt import PromptTemplate\n",
|
|
"\n",
|
|
"\n",
|
|
"chain = MyCustomChain(\n",
|
|
" prompt=PromptTemplate.from_template(\"tell us a joke about {topic}\"),\n",
|
|
" llm=ChatOpenAI(),\n",
|
|
")\n",
|
|
"\n",
|
|
"chain.run({\"topic\": \"callbacks\"}, callbacks=[StdOutCallbackHandler()])"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|