Docs: Add custom chat model documentation (#17595)

This PR adds documentation about how to implement a custom chat model.
Eugene Yurtsev 7 months ago committed by GitHub
parent 07ee41d284
commit 865cabff05

@@ -0,0 +1,644 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "e3da9a3f-f583-4ba6-994e-0e8c1158f5eb",
"metadata": {},
"source": [
"# Custom Chat Model\n",
"\n",
"In this guide, we'll learn how to create a custom chat model using LangChain abstractions.\n",
"\n",
"Wrapping your LLM with the standard `ChatModel` interface allow you to use your LLM in existing LangChain programs with minimal code modifications!\n",
"\n",
"As an bonus, your LLM will automatically become a LangChain `Runnable` and will benefit from some optimizations out of the box (e.g., batch via a threadpool), async support, the `astream_events` API, etc.\n",
"\n",
"## Inputs and outputs\n",
"\n",
"First, we need to talk about messages which are the inputs and outputs of chat models.\n",
"\n",
"### Messages\n",
"\n",
"Chat models take messages as inputs and return a message as output. \n",
"\n",
"LangChain has a few built-in message types:\n",
"\n",
"- `SystemMessage`: Used for priming AI behavior, usually passed in as the first of a sequence of input messages.\n",
"- `HumanMessage`: Represents a message from a person interacting with the chat model.\n",
"- `AIMessage`: Represents a message from the chat model. This can be either text or a request to invoke a tool.\n",
"- `FunctionMessage` / `ToolMessage`: Message for passing the results of tool invocation back to the model.\n",
"\n",
"::: {.callout-note}\n",
"`ToolMessage` and `FunctionMessage` closely follow OpenAIs `function` and `tool` arguments.\n",
"\n",
"This is a rapidly developing field and as more models add function calling capabilities, expect that there will be additions to this schema.\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5046e6a-8b09-4a99-b6e6-7a605aac5738",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import (\n",
" AIMessage,\n",
" BaseMessage,\n",
" FunctionMessage,\n",
" HumanMessage,\n",
" SystemMessage,\n",
" ToolMessage,\n",
")"
]
},
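{
"cell_type": "markdown",
"id": "2f1a6c3e-9b4d-4e8a-8c71-51e0a7d2b9a1",
"metadata": {},
"source": [
"For example, a short conversation could be represented as a list of these messages (the exchange below is purely illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f1a6c3e-9b4d-4e8a-8c71-51e0a7d2b9a2",
"metadata": {},
"outputs": [],
"source": [
"# An illustrative conversation: a system prompt, a user turn, and a model reply.\n",
"conversation = [\n",
"    SystemMessage(content=\"You are a terse assistant.\"),\n",
"    HumanMessage(content=\"What is the capital of France?\"),\n",
"    AIMessage(content=\"Paris.\"),\n",
"]"
]
},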
{
"cell_type": "markdown",
"id": "53033447-8260-4f53-bd6f-b2f744e04e75",
"metadata": {},
"source": [
"### Streaming Variant\n",
"\n",
"All the chat messages have a streaming variant that contains `Chunk` in the name."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d4656e9d-bfa1-4703-8f79-762fe6421294",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import (\n",
" AIMessageChunk,\n",
" FunctionMessageChunk,\n",
" HumanMessageChunk,\n",
" SystemMessageChunk,\n",
" ToolMessageChunk,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "81ebf3f4-c760-4898-b921-fdb469453d4a",
"metadata": {},
"source": [
"These chunks are used when streaming output from chat models, and they all define an additive property!"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9c15c299-6f8a-49cf-a072-09924fd44396",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessageChunk(content='Hello World!')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"AIMessageChunk(content=\"Hello\") + AIMessageChunk(content=\" World!\")"
]
},
{
"cell_type": "markdown",
"id": "8e952d64-6d38-4a2b-b996-8812c204a12c",
"metadata": {},
"source": [
"## Simple Chat Model\n",
"\n",
"Inherting from `SimpleChatModel` is great for prototyping!\n",
"\n",
"It won't allow you to implement all features that you might want out of a chat model, but it's quick to implement, and if you need more you can transition to `BaseChatModel` shown below.\n",
"\n",
"Let's implement a chat model that echoes back the last `n` characters of the prompt!\n",
"\n",
"You need to implement the following:\n",
"\n",
"* The method `_call` - Use to generate a chat result from a prompt.\n",
"\n",
"In addition, you have the option to specify the following:\n",
"\n",
"* The property `_identifying_params` - Represent model parameterization for logging purposes.\n",
"\n",
"Optional:\n",
"\n",
"* `_stream` - Use to implement streaming.\n"
]
},
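{
"cell_type": "markdown",
"id": "7d3b5a1c-2e4f-4a6b-9c8d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"Below is a minimal sketch of what this could look like. The class name `CustomChatModelSimple` is illustrative (it is not part of LangChain), and `_llm_type` is implemented as well, since `BaseChatModel` declares it as an abstract property:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d3b5a1c-2e4f-4a6b-9c8d-0e1f2a3b4c5e",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, List, Optional\n",
"\n",
"from langchain_core.callbacks import CallbackManagerForLLMRun\n",
"from langchain_core.language_models import SimpleChatModel\n",
"from langchain_core.messages import BaseMessage\n",
"\n",
"\n",
"class CustomChatModelSimple(SimpleChatModel):\n",
"    \"\"\"A sample chat model that echoes the last `n` characters of the prompt.\"\"\"\n",
"\n",
"    n: int\n",
"    \"\"\"The number of characters from the last message of the prompt to be echoed.\"\"\"\n",
"\n",
"    @property\n",
"    def _llm_type(self) -> str:\n",
"        \"\"\"Used to uniquely identify the type of the model. Used for logging.\"\"\"\n",
"        return \"echoing-chat-model-simple\"\n",
"\n",
"    def _call(\n",
"        self,\n",
"        messages: List[BaseMessage],\n",
"        stop: Optional[List[str]] = None,\n",
"        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
"        **kwargs: Any,\n",
"    ) -> str:\n",
"        \"\"\"Return the completion as a plain string.\n",
"\n",
"        `SimpleChatModel` wraps the string in an `AIMessage` for us.\n",
"        \"\"\"\n",
"        # Echo the last `n` characters of the final message in the prompt.\n",
"        return messages[-1].content[-self.n :]\n",
"\n",
"\n",
"# Illustrative usage: echoes the last 3 characters of \"meow\".\n",
"CustomChatModelSimple(n=3).invoke(\"meow\")  # -> AIMessage(content='eow')"
]
},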
{
"cell_type": "markdown",
"id": "bbfebea1",
"metadata": {},
"source": [
"## Base Chat Model\n",
"\n",
"Let's implement a chat model that echoes back the first `n` characetrs of the last message in the prompt!\n",
"\n",
"To do so, we will inherit from `BaseChatModel` and we'll need to implement the following methods/properties:\n",
"\n",
"In addition, you have the option to specify the following:\n",
"\n",
"To do so inherit from `BaseChatModel` which is a lower level class and implement the methods:\n",
"\n",
"* `_generate` - Use to generate a chat result from a prompt\n",
"* The property `_llm_type` - Used to uniquely identify the type of the model. Used for logging.\n",
"\n",
"Optional:\n",
"\n",
"* `_stream` - Use to implement streaming.\n",
"* `_agenerate` - Use to implement a native async method.\n",
"* `_astream` - Use to implement async version of `_stream`.\n",
"* The property `_identifying_params` - Represent model parameterization for logging purposes.\n",
"\n",
"\n",
":::{.callout-caution}\n",
"\n",
"Currently, to get async streaming to work (via `astream`), you must provide an implementation of `_astream`.\n",
"\n",
"By default if `_astream` is not provided, then async streaming falls back on `_agenerate` which does not support\n",
"token by token streaming.\n",
":::"
]
},
{
"cell_type": "markdown",
"id": "8e7047bd-c235-46f6-85e1-d6d7e0868eb1",
"metadata": {},
"source": [
"### Implementation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "25ba32e5-5a6d-49f4-bb68-911827b84d61",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, AsyncIterator, Dict, Iterator, List, Optional\n",
"\n",
"from langchain_core.callbacks import (\n",
" AsyncCallbackManagerForLLMRun,\n",
" CallbackManagerForLLMRun,\n",
")\n",
"from langchain_core.language_models import BaseChatModel, SimpleChatModel\n",
"from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage\n",
"from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult\n",
"from langchain_core.runnables import run_in_executor\n",
"\n",
"\n",
"class CustomChatModelAdvanced(BaseChatModel):\n",
" \"\"\"A custom chat model that echoes the first `n` characters of the input.\n",
"\n",
" When contributing an implementation to LangChain, carefully document\n",
" the model including the initialization parameters, include\n",
" an example of how to initialize the model and include any relevant\n",
" links to the underlying models documentation or API.\n",
"\n",
" Example:\n",
"\n",
" .. code-block:: python\n",
"\n",
" model = CustomChatModel(n=2)\n",
" result = model.invoke([HumanMessage(content=\"hello\")])\n",
" result = model.batch([[HumanMessage(content=\"hello\")],\n",
" [HumanMessage(content=\"world\")]])\n",
" \"\"\"\n",
"\n",
" n: int\n",
" \"\"\"The number of characters from the last message of the prompt to be echoed.\"\"\"\n",
"\n",
" def _generate(\n",
" self,\n",
" messages: List[BaseMessage],\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" **kwargs: Any,\n",
" ) -> ChatResult:\n",
" \"\"\"Override the _generate method to implement the chat model logic.\n",
"\n",
" This can be a call to an API, a call to a local model, or any other\n",
" implementation that generates a response to the input prompt.\n",
"\n",
" Args:\n",
" messages: the prompt composed of a list of messages.\n",
" stop: a list of strings on which the model should stop generating.\n",
" If generation stops due to a stop token, the stop token itself\n",
" SHOULD BE INCLUDED as part of the output. This is not enforced\n",
" across models right now, but it's a good practice to follow since\n",
" it makes it much easier to parse the output of the model\n",
" downstream and understand why generation stopped.\n",
" run_manager: A run manager with callbacks for the LLM.\n",
" \"\"\"\n",
" last_message = messages[-1]\n",
" tokens = last_message.content[: self.n]\n",
" message = AIMessage(content=tokens)\n",
" generation = ChatGeneration(message=message)\n",
" return ChatResult(generations=[generation])\n",
"\n",
" def _stream(\n",
" self,\n",
" messages: List[BaseMessage],\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" **kwargs: Any,\n",
" ) -> Iterator[ChatGenerationChunk]:\n",
" \"\"\"Stream the output of the model.\n",
"\n",
" This method should be implemented if the model can generate output\n",
" in a streaming fashion. If the model does not support streaming,\n",
" do not implement it. In that case streaming requests will be automatically\n",
" handled by the _generate method.\n",
"\n",
" Args:\n",
" messages: the prompt composed of a list of messages.\n",
" stop: a list of strings on which the model should stop generating.\n",
" If generation stops due to a stop token, the stop token itself\n",
" SHOULD BE INCLUDED as part of the output. This is not enforced\n",
" across models right now, but it's a good practice to follow since\n",
" it makes it much easier to parse the output of the model\n",
" downstream and understand why generation stopped.\n",
" run_manager: A run manager with callbacks for the LLM.\n",
" \"\"\"\n",
" last_message = messages[-1]\n",
" tokens = last_message.content[: self.n]\n",
"\n",
" for token in tokens:\n",
" chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))\n",
"\n",
" if run_manager:\n",
" run_manager.on_llm_new_token(token, chunk=chunk)\n",
"\n",
" yield chunk\n",
"\n",
" async def _astream(\n",
" self,\n",
" messages: List[BaseMessage],\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,\n",
" **kwargs: Any,\n",
" ) -> AsyncIterator[ChatGenerationChunk]:\n",
" \"\"\"An async variant of astream.\n",
"\n",
" If not provided, the default behavior is to delegate to the _generate method.\n",
"\n",
" The implementation below instead will delegate to `_stream` and will\n",
" kick it off in a separate thread.\n",
"\n",
" If you're able to natively support async, then by all means do so!\n",
" \"\"\"\n",
" result = await run_in_executor(\n",
" None,\n",
" self._stream,\n",
" messages,\n",
" stop=stop,\n",
" run_manager=run_manager.get_sync() if run_manager else None,\n",
" **kwargs,\n",
" )\n",
" for chunk in result:\n",
" yield chunk\n",
"\n",
" @property\n",
" def _llm_type(self) -> str:\n",
" \"\"\"Get the type of language model used by this chat model.\"\"\"\n",
" return \"echoing-chat-model-advanced\"\n",
"\n",
" @property\n",
" def _identifying_params(self) -> Dict[str, Any]:\n",
" \"\"\"Return a dictionary of identifying parameters.\"\"\"\n",
" return {\"n\": self.n}"
]
},
{
"cell_type": "markdown",
"id": "b3c3d030-8d8b-4891-962d-a2d39b331883",
"metadata": {},
"source": [
":::{.callout-tip}\n",
"The `_astream` implementation uses `run_in_executor` to launch the sync `_stream` in a separate thread.\n",
"\n",
"You can use this trick if you want to reuse the `_stream` implementation, but if you're able to implement code\n",
"that's natively async that's a better solution since that code will run with less overhead.\n",
":::"
]
},
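{
"cell_type": "markdown",
"id": "9e8d7c6b-5a4f-4e3d-8b2a-1c0d9e8f7a6b",
"metadata": {},
"source": [
"For comparison, here is an illustrative sketch of what a natively async `_astream` body could look like for this echo model. It needs no thread hop, and the async run manager's `on_llm_new_token` hook is awaited:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e8d7c6b-5a4f-4e3d-8b2a-1c0d9e8f7a6c",
"metadata": {},
"outputs": [],
"source": [
"# A hypothetical natively async `_astream` for CustomChatModelAdvanced.\n",
"# Shown standalone for illustration; in practice it would be defined\n",
"# on the class in place of the `run_in_executor` version above.\n",
"async def _astream(\n",
"    self,\n",
"    messages: List[BaseMessage],\n",
"    stop: Optional[List[str]] = None,\n",
"    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,\n",
"    **kwargs: Any,\n",
") -> AsyncIterator[ChatGenerationChunk]:\n",
"    last_message = messages[-1]\n",
"    tokens = last_message.content[: self.n]\n",
"\n",
"    for token in tokens:\n",
"        chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))\n",
"\n",
"        if run_manager:\n",
"            # On the async callback manager, the callbacks are coroutines.\n",
"            await run_manager.on_llm_new_token(token, chunk=chunk)\n",
"\n",
"        yield chunk"
]
},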
{
"cell_type": "markdown",
"id": "1e9af284-f2d3-44e2-ac6a-09b73d89ada3",
"metadata": {},
"source": [
"### Let's test it 🧪\n",
"\n",
"The chat model will implement the standard `Runnable` interface of LangChain which many of the LangChain abstractions support!"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "34bf2d48-556a-48be-aee7-496fb02332f3",
"metadata": {},
"outputs": [],
"source": [
"model = CustomChatModelAdvanced(n=3)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "27689f30-dcd2-466b-ba9d-f60b7d434110",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Meo')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.invoke(\n",
" [\n",
" HumanMessage(content=\"hello!\"),\n",
" AIMessage(content=\"Hi there human!\"),\n",
" HumanMessage(content=\"Meow!\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "406436df-31bf-466b-9c3d-39db9d6b6407",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='hel')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.invoke(\"hello\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a72ffa46-6004-41ef-bbe4-56fa17a029e2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content='hel'), AIMessage(content='goo')]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.batch([\"hello\", \"goodbye\"])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3633be2c-2ea0-42f9-a72f-3b5240690b55",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"c|a|t|"
]
}
],
"source": [
"for chunk in model.stream(\"cat\"):\n",
" print(chunk.content, end=\"|\")"
]
},
{
"cell_type": "markdown",
"id": "3f8a7c42-aec4-4116-adf3-93133d409827",
"metadata": {},
"source": [
"Please see the implementation of `_astream` in the model! If you do not implement it, then no output will stream.!"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "b7d73995-eeab-48c6-a7d8-32c98ba29fc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"c|a|t|"
]
}
],
"source": [
"async for chunk in model.astream(\"cat\"):\n",
" print(chunk.content, end=\"|\")"
]
},
{
"cell_type": "markdown",
"id": "f80dc55b-d159-4527-9191-407a7c6d6042",
"metadata": {},
"source": [
"Let's try to use the astream events API which will also help double check that all the callbacks were implemented!"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "17840eba-8ff4-4e73-8e4f-85f16eb1c9d0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_start', 'run_id': 'e03c0b21-521f-4cb4-a837-02fed65cf1cf', 'name': 'CustomChatModelAdvanced', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}}\n",
"{'event': 'on_chat_model_stream', 'run_id': 'e03c0b21-521f-4cb4-a837-02fed65cf1cf', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='c')}}\n",
"{'event': 'on_chat_model_stream', 'run_id': 'e03c0b21-521f-4cb4-a837-02fed65cf1cf', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='a')}}\n",
"{'event': 'on_chat_model_stream', 'run_id': 'e03c0b21-521f-4cb4-a837-02fed65cf1cf', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='t')}}\n",
"{'event': 'on_chat_model_end', 'name': 'CustomChatModelAdvanced', 'run_id': 'e03c0b21-521f-4cb4-a837-02fed65cf1cf', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat')}}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/eugene/src/langchain/libs/core/langchain_core/_api/beta_decorator.py:86: LangChainBetaWarning: This API is in beta and may change in the future.\n",
" warn_beta(\n"
]
}
],
"source": [
"async for event in model.astream_events(\"cat\", version=\"v1\"):\n",
" print(event)"
]
},
{
"cell_type": "markdown",
"id": "42f9553f-7d8c-4277-aeb4-d80d77839d90",
"metadata": {},
"source": [
"## Identifying Params\n",
"\n",
"LangChain has a callback system which allows implementing loggers to monitor the behavior of LLM applications.\n",
"\n",
"Remember the `_identifying_params` property from earlier? \n",
"\n",
"It's passed to the callback system and is accessible for user specified loggers.\n",
"\n",
"Below we'll implement a handler with just a single `on_chat_model_start` event to see where `_identifying_params` appears."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cc7e6b5f-711b-48aa-9ebe-92a13e230c37",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---\n",
"On chat model start.\n",
"{'invocation_params': {'n': 3, '_type': 'echoing-chat-model-advanced', 'stop': ['woof']}, 'options': {'stop': ['woof']}, 'name': None, 'batch_size': 1}\n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content='meo')"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from typing import Union\n",
"from uuid import UUID\n",
"\n",
"from langchain_core.callbacks import AsyncCallbackHandler\n",
"from langchain_core.outputs import (\n",
" ChatGenerationChunk,\n",
" ChatResult,\n",
" GenerationChunk,\n",
" LLMResult,\n",
")\n",
"\n",
"\n",
"class SampleCallbackHandler(AsyncCallbackHandler):\n",
" \"\"\"Async callback handler that handles callbacks from LangChain.\"\"\"\n",
"\n",
" async def on_chat_model_start(\n",
" self,\n",
" serialized: Dict[str, Any],\n",
" messages: List[List[BaseMessage]],\n",
" *,\n",
" run_id: UUID,\n",
" parent_run_id: Optional[UUID] = None,\n",
" tags: Optional[List[str]] = None,\n",
" metadata: Optional[Dict[str, Any]] = None,\n",
" **kwargs: Any,\n",
" ) -> Any:\n",
" \"\"\"Run when a chat model starts running.\"\"\"\n",
" print(\"---\")\n",
" print(\"On chat model start.\")\n",
" print(kwargs)\n",
"\n",
"\n",
"model.invoke(\"meow\", stop=[\"woof\"], config={\"callbacks\": [SampleCallbackHandler()]})"
]
},
{
"cell_type": "markdown",
"id": "44ee559b-b1da-4851-8c97-420ab394aff9",
"metadata": {},
"source": [
"## Contributing\n",
"\n",
"We appreciate all chat model integration contributions. \n",
"\n",
"Here's a checklist to help make sure your contribution gets added to LangChain:\n",
"\n",
"Documentation:\n",
"\n",
"* The model contains doc-strings for all initialization arguments, as these will be surfaced in the [APIReference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).\n",
"* The class doc-string for the model contains a link to the model API if the model is powered by a service.\n",
"\n",
"Tests:\n",
"\n",
"* [ ] Add unit or integration tests to the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, `stream` work if you've over-ridden the corresponding code.\n",
"\n",
"Streaming (if you're implementing it):\n",
"\n",
"* [ ] Provided an async implementation via `_astream`\n",
"* [ ] Make sure to invoke the `on_llm_new_token` callback\n",
"* [ ] `on_llm_new_token` is invoked BEFORE yielding the chunk\n",
"\n",
"Stop Token Behavior:\n",
"\n",
"* [ ] Stop token should be respected\n",
"* [ ] Stop token should be INCLUDED as part of the response\n",
"\n",
"Secret API Keys:\n",
"\n",
"* [ ] If your model connects to an API it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets, so they don't get accidentally printed out when folks print the model."
]
}
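,
{
"cell_type": "markdown",
"id": "5f4e3d2c-1b0a-4f9e-8d7c-6b5a4e3d2c1b",
"metadata": {},
"source": [
"As a minimal illustration of the `SecretStr` pattern, here is a sketch using plain `pydantic`; the class `ExampleModelConfig` is hypothetical, and a real integration would declare the field on the chat model class itself:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f4e3d2c-1b0a-4f9e-8d7c-6b5a4e3d2c1c",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"\n",
"from pydantic import BaseModel, SecretStr\n",
"\n",
"\n",
"class ExampleModelConfig(BaseModel):\n",
"    \"\"\"Hypothetical config showing how an API key field might be declared.\"\"\"\n",
"\n",
"    api_key: Optional[SecretStr] = None\n",
"\n",
"\n",
"config = ExampleModelConfig(api_key=\"an-example-key\")\n",
"print(config)  # The key is masked, e.g., api_key=SecretStr('**********')\n",
"print(config.api_key.get_secret_value())  # Unmask only when calling the API"
]
}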
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -4,11 +4,13 @@ sidebar_position: 3
 # Chat Models
-ChatModels are a core component of LangChain.
-LangChain does not serve its own ChatModels, but rather provides a standard interface for interacting with many different models. To be specific, this interface is one that takes as input a list of messages and returns a message.
+Chat Models are a core component of LangChain.
+A chat model is a language model that uses chat messages as inputs and returns chat messages as outputs (as opposed to using plain text).
-There are lots of model providers (OpenAI, Cohere, Hugging Face, etc) - the `ChatModel` class is designed to provide a standard interface for all of them.
+LangChain has integrations with many model providers (OpenAI, Cohere, Hugging Face, etc.) and exposes a standard interface to interact with all of these models.
+LangChain allows you to use models in sync, async, batching and streaming modes and provides other features (e.g., caching) and more.
## [Quick Start](./quick_start)
@@ -27,3 +29,4 @@ This includes:
 - [How to use ChatModels that support function calling](./function_calling)
 - [How to stream responses from a ChatModel](./streaming)
 - [How to track token usage in a ChatModel call](./token_usage_tracking)
+- [How to create a custom ChatModel](./custom_chat_model)

@@ -794,7 +794,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
 class SimpleChatModel(BaseChatModel):
-    """Simple Chat Model."""
+    """A simplified implementation for a chat model to inherit from."""
 
     def _generate(
         self,