community[minor]: add chat model llamacpp (#22589)

- **PR title**: [community] add chat model llamacpp


- **PR message**:
- **Description:** This PR introduces a new chat model integration with
llamacpp_python, designed to work similarly to the existing ChatOpenAI
model.
      + Work well with instructed chat, chain and function/tool calling.
+ Work with LangGraph (persistent memory, tool calling), will update
soon

- **Dependencies:** This change requires the llamacpp_python library to
be installed.
    
@baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
pull/22513/head
Thanh Nguyen 4 months ago committed by GitHub
parent e4279f80cd
commit b5e2ba3a47
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,595 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChatLlamaCpp\n",
"\n",
"This notebook provides a quick overview for getting started with chat model intergrated with [llama cpp python](https://github.com/abetlen/llama-cpp-python)\n",
"\n",
"An example below demonstrating how to implement with the open-source Llama3 Instruct 8B"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview\n",
"\n",
"### Integration details\n",
"| Class | Package | Local | Serializable | JS support |\n",
"| :--- | :--- | :---: | :---: | :---: |\n",
"| [ChatLlamaCpp](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.llamacpp.ChatLlamaCpp.html) | [langchain-community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ✅ | ❌ | ❌ |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | Image input | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | \n",
"\n",
"## Setup\n",
"\n",
"### Installation\n",
"\n",
"The LangChain OpenAI integration lives in the `langchain-community` and `llama-cpp-python` packages:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-community llama-cpp-python"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"llama_model_loader: loaded meta data with 22 key-value pairs and 291 tensors from /home/tni5hc/Documents/langchain_llamacpp/SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf (version GGUF V3 (latest))\n",
"llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n",
"llama_model_loader: - kv 0: general.architecture str = llama\n",
"llama_model_loader: - kv 1: general.name str = Meta-Llama-3-8B-Instruct\n",
"llama_model_loader: - kv 2: llama.block_count u32 = 32\n",
"llama_model_loader: - kv 3: llama.context_length u32 = 8192\n",
"llama_model_loader: - kv 4: llama.embedding_length u32 = 4096\n",
"llama_model_loader: - kv 5: llama.feed_forward_length u32 = 14336\n",
"llama_model_loader: - kv 6: llama.attention.head_count u32 = 32\n",
"llama_model_loader: - kv 7: llama.attention.head_count_kv u32 = 8\n",
"llama_model_loader: - kv 8: llama.rope.freq_base f32 = 500000.000000\n",
"llama_model_loader: - kv 9: llama.attention.layer_norm_rms_epsilon f32 = 0.000010\n",
"llama_model_loader: - kv 10: general.file_type u32 = 7\n",
"llama_model_loader: - kv 11: llama.vocab_size u32 = 128256\n",
"llama_model_loader: - kv 12: llama.rope.dimension_count u32 = 128\n",
"llama_model_loader: - kv 13: tokenizer.ggml.model str = gpt2\n",
"llama_model_loader: - kv 14: tokenizer.ggml.pre str = llama-bpe\n",
"llama_model_loader: - kv 15: tokenizer.ggml.tokens arr[str,128256] = [\"!\", \"\\\"\", \"#\", \"$\", \"%\", \"&\", \"'\", ...\n",
"llama_model_loader: - kv 16: tokenizer.ggml.token_type arr[i32,128256] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...\n",
"llama_model_loader: - kv 17: tokenizer.ggml.merges arr[str,280147] = [\"Ġ Ġ\", \"Ġ ĠĠĠ\", \"ĠĠ ĠĠ\", \"...\n",
"llama_model_loader: - kv 18: tokenizer.ggml.bos_token_id u32 = 128000\n",
"llama_model_loader: - kv 19: tokenizer.ggml.eos_token_id u32 = 128009\n",
"llama_model_loader: - kv 20: tokenizer.chat_template str = {% set loop_messages = messages %}{% ...\n",
"llama_model_loader: - kv 21: general.quantization_version u32 = 2\n",
"llama_model_loader: - type f32: 65 tensors\n",
"llama_model_loader: - type q8_0: 226 tensors\n",
"llm_load_vocab: special tokens definition check successful ( 256/128256 ).\n",
"llm_load_print_meta: format = GGUF V3 (latest)\n",
"llm_load_print_meta: arch = llama\n",
"llm_load_print_meta: vocab type = BPE\n",
"llm_load_print_meta: n_vocab = 128256\n",
"llm_load_print_meta: n_merges = 280147\n",
"llm_load_print_meta: n_ctx_train = 8192\n",
"llm_load_print_meta: n_embd = 4096\n",
"llm_load_print_meta: n_head = 32\n",
"llm_load_print_meta: n_head_kv = 8\n",
"llm_load_print_meta: n_layer = 32\n",
"llm_load_print_meta: n_rot = 128\n",
"llm_load_print_meta: n_embd_head_k = 128\n",
"llm_load_print_meta: n_embd_head_v = 128\n",
"llm_load_print_meta: n_gqa = 4\n",
"llm_load_print_meta: n_embd_k_gqa = 1024\n",
"llm_load_print_meta: n_embd_v_gqa = 1024\n",
"llm_load_print_meta: f_norm_eps = 0.0e+00\n",
"llm_load_print_meta: f_norm_rms_eps = 1.0e-05\n",
"llm_load_print_meta: f_clamp_kqv = 0.0e+00\n",
"llm_load_print_meta: f_max_alibi_bias = 0.0e+00\n",
"llm_load_print_meta: f_logit_scale = 0.0e+00\n",
"llm_load_print_meta: n_ff = 14336\n",
"llm_load_print_meta: n_expert = 0\n",
"llm_load_print_meta: n_expert_used = 0\n",
"llm_load_print_meta: causal attn = 1\n",
"llm_load_print_meta: pooling type = 0\n",
"llm_load_print_meta: rope type = 0\n",
"llm_load_print_meta: rope scaling = linear\n",
"llm_load_print_meta: freq_base_train = 500000.0\n",
"llm_load_print_meta: freq_scale_train = 1\n",
"llm_load_print_meta: n_yarn_orig_ctx = 8192\n",
"llm_load_print_meta: rope_finetuned = unknown\n",
"llm_load_print_meta: ssm_d_conv = 0\n",
"llm_load_print_meta: ssm_d_inner = 0\n",
"llm_load_print_meta: ssm_d_state = 0\n",
"llm_load_print_meta: ssm_dt_rank = 0\n",
"llm_load_print_meta: model type = 7B\n",
"llm_load_print_meta: model ftype = Q8_0\n",
"llm_load_print_meta: model params = 8.03 B\n",
"llm_load_print_meta: model size = 7.95 GiB (8.50 BPW) \n",
"llm_load_print_meta: general.name = Meta-Llama-3-8B-Instruct\n",
"llm_load_print_meta: BOS token = 128000 '<|begin_of_text|>'\n",
"llm_load_print_meta: EOS token = 128009 '<|eot_id|>'\n",
"llm_load_print_meta: LF token = 128 'Ä'\n",
"ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no\n",
"ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes\n",
"ggml_cuda_init: found 1 CUDA devices:\n",
" Device 0: NVIDIA RTX A2000 12GB, compute capability 8.6, VMM: yes\n",
"llm_load_tensors: ggml ctx size = 0.22 MiB\n",
"llm_load_tensors: offloading 8 repeating layers to GPU\n",
"llm_load_tensors: offloaded 8/33 layers to GPU\n",
"llm_load_tensors: CPU buffer size = 8137.64 MiB\n",
"llm_load_tensors: CUDA0 buffer size = 1768.25 MiB\n",
".........................................................................................\n",
"llama_new_context_with_model: n_ctx = 10016\n",
"llama_new_context_with_model: n_batch = 300\n",
"llama_new_context_with_model: n_ubatch = 300\n",
"llama_new_context_with_model: freq_base = 10000.0\n",
"llama_new_context_with_model: freq_scale = 1\n",
"llama_kv_cache_init: CUDA_Host KV buffer size = 939.00 MiB\n",
"llama_kv_cache_init: CUDA0 KV buffer size = 313.00 MiB\n",
"llama_new_context_with_model: KV self size = 1252.00 MiB, K (f16): 626.00 MiB, V (f16): 626.00 MiB\n",
"llama_new_context_with_model: CUDA_Host output buffer size = 0.49 MiB\n",
"llama_new_context_with_model: CUDA0 compute buffer size = 683.78 MiB\n",
"llama_new_context_with_model: CUDA_Host compute buffer size = 16.15 MiB\n",
"llama_new_context_with_model: graph nodes = 1030\n",
"llama_new_context_with_model: graph splits = 268\n",
"AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | \n",
"Model metadata: {'tokenizer.chat_template': \"{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}\", 'tokenizer.ggml.eos_token_id': '128009', 'general.quantization_version': '2', 'tokenizer.ggml.model': 'gpt2', 'general.architecture': 'llama', 'llama.rope.freq_base': '500000.000000', 'tokenizer.ggml.pre': 'llama-bpe', 'llama.context_length': '8192', 'general.name': 'Meta-Llama-3-8B-Instruct', 'llama.embedding_length': '4096', 'llama.feed_forward_length': '14336', 'llama.attention.layer_norm_rms_epsilon': '0.000010', 'tokenizer.ggml.bos_token_id': '128000', 'llama.attention.head_count': '32', 'llama.block_count': '32', 'llama.attention.head_count_kv': '8', 'general.file_type': '7', 'llama.vocab_size': '128256', 'llama.rope.dimension_count': '128'}\n",
"Using gguf chat template: {% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n",
"\n",
"'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n",
"\n",
"' }}{% endif %}\n",
"Using chat eos_token: <|eot_id|>\n",
"Using chat bos_token: <|begin_of_text|>\n"
]
}
],
"source": [
"import multiprocessing\n",
"\n",
"from langchain_community.chat_models import ChatLlamaCpp\n",
"\n",
"llm = ChatLlamaCpp(\n",
" temperature=0.5,\n",
" model_path=\"./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf\",\n",
" n_ctx=10000,\n",
" n_gpu_layers=8,\n",
" n_batch=300, # Should be between 1 and n_ctx, consider the amount of VRAM in your GPU.\n",
" max_tokens=512,\n",
" n_threads=multiprocessing.cpu_count() - 1,\n",
" repeat_penalty=1.5,\n",
" top_p=0.5,\n",
" verbose=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"llama_print_timings: load time = 1077.71 ms\n",
"llama_print_timings: sample time = 21.82 ms / 39 runs ( 0.56 ms per token, 1787.35 tokens per second)\n",
"llama_print_timings: prompt eval time = 1077.65 ms / 37 tokens ( 29.13 ms per token, 34.33 tokens per second)\n",
"llama_print_timings: eval time = 8403.75 ms / 38 runs ( 221.15 ms per token, 4.52 tokens per second)\n",
"llama_print_timings: total time = 9689.66 ms / 75 tokens\n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content='Je adore le programmation.\\n\\n(Note: \"programmation\" is used in both formal and informal contexts, but it\\'s generally accepted as equivalent of saying you like computer science or coding.)', response_metadata={'finish_reason': 'stop'}, id='run-e9e03b94-f29f-4c1d-8483-e23a46acb556-0')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Je adore le programmation.\n",
"\n",
"(Note: \"programmation\" is used in both formal and informal contexts, but it's generally accepted as equivalent of saying you like computer science or coding.)\n"
]
}
],
"source": [
"print(ai_msg.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Llama.generate: prefix-match hit\n",
"\n",
"llama_print_timings: load time = 1077.71 ms\n",
"llama_print_timings: sample time = 29.23 ms / 52 runs ( 0.56 ms per token, 1778.75 tokens per second)\n",
"llama_print_timings: prompt eval time = 869.38 ms / 17 tokens ( 51.14 ms per token, 19.55 tokens per second)\n",
"llama_print_timings: eval time = 6694.18 ms / 51 runs ( 131.26 ms per token, 7.62 tokens per second)\n",
"llama_print_timings: total time = 7830.86 ms / 68 tokens\n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe auch Programmieren! (Translation: I also like coding!) Do you have any favorite languages or projects? Ich bin hier, um dir zu helfen und über deine Lieblingsprogrammierthemen sprechen können wir gerne weiter machen... !)', response_metadata={'finish_reason': 'stop'}, id='run-922c4cad-368f-41ba-9db9-eacb41d37cb2-0')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool calling\n",
"\n",
"Firstly, it works mostly the same as OpenAI Function Calling\n",
"\n",
"OpenAI has a [tool calling](https://platform.openai.com/docs/guides/function-calling) (we use \"tool calling\" and \"function calling\" interchangeably here) API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool. tool-calling is extremely useful for building tool-using chains and agents, and for getting structured outputs from models more generally.\n",
"\n",
"With `ChatLlamaCpp.bind_tools`, we can easily pass in Pydantic classes, dict schemas, LangChain tools, or even functions as tools to the model. Under the hood these are converted to an OpenAI tool schemas, which looks like:\n",
"```\n",
"{\n",
" \"name\": \"...\",\n",
" \"description\": \"...\",\n",
" \"parameters\": {...} # JSONSchema\n",
"}\n",
"```\n",
"and passed in every model invocation.\n",
"\n",
"\n",
"However, it cannot automatically trigger a function/tool, we need to force it by specifying the 'tool choice' parameter. This parameter is typically formatted as described below.\n",
"\n",
"```{\"type\": \"function\", \"function\": {\"name\": <<tool_name>>}}.```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from langchain.tools import tool\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class WeatherInput(BaseModel):\n",
" location: str = Field(description=\"The city and state, e.g. San Francisco, CA\")\n",
" unit: str = Field(enum=[\"celsius\", \"fahrenheit\"])\n",
"\n",
"\n",
"@tool(\"get_current_weather\", args_schema=WeatherInput)\n",
"def get_weather(location: str, unit: str):\n",
" \"\"\"Get the current weather in a given location\"\"\"\n",
" return f\"Now the weather in {location} is 22 {unit}\"\n",
"\n",
"\n",
"llm_with_tools = llm.bind_tools(\n",
" tools=[get_weather],\n",
" tool_choice={\"type\": \"function\", \"function\": {\"name\": \"get_current_weather\"}},\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Llama.generate: prefix-match hit\n",
"\n",
"llama_print_timings: load time = 1077.71 ms\n",
"llama_print_timings: sample time = 853.67 ms / 20 runs ( 42.68 ms per token, 23.43 tokens per second)\n",
"llama_print_timings: prompt eval time = 1060.96 ms / 21 tokens ( 50.52 ms per token, 19.79 tokens per second)\n",
"llama_print_timings: eval time = 2754.74 ms / 19 runs ( 144.99 ms per token, 6.90 tokens per second)\n",
"llama_print_timings: total time = 4817.07 ms / 40 tokens\n"
]
},
{
"data": {
"text/plain": [
"AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_current_weather', 'arguments': '{ \"location\": \"Ho Chi Minh City\", \"unit\" : \"celsius\"}'}, 'tool_calls': [{'id': 'call__0_get_current_weather_cmpl-3e329fde-4fa6-41b9-837c-131fa9494554', 'type': 'function', 'function': {'name': 'get_current_weather', 'arguments': '{ \"location\": \"Ho Chi Minh City\", \"unit\" : \"celsius\"}'}}]}, response_metadata={'token_usage': {'prompt_tokens': 23, 'completion_tokens': 19, 'total_tokens': 42}, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-9d35869c-36fe-4f4a-835e-089a3f3aba3c-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'Ho Chi Minh City', 'unit': 'celsius'}, 'id': 'call__0_get_current_weather_cmpl-3e329fde-4fa6-41b9-837c-131fa9494554'}])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ai_msg = llm_with_tools.invoke(\n",
" \"what is the weather like in HCMC in celsius\",\n",
")\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'get_current_weather',\n",
" 'args': {'location': 'Ho Chi Minh City', 'unit': 'celsius'},\n",
" 'id': 'call__0_get_current_weather_cmpl-3e329fde-4fa6-41b9-837c-131fa9494554'}]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ai_msg.tool_calls"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Structured output"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Llama.generate: prefix-match hit\n",
"\n",
"llama_print_timings: load time = 1077.71 ms\n",
"llama_print_timings: sample time = 1964.76 ms / 44 runs ( 44.65 ms per token, 22.39 tokens per second)\n",
"llama_print_timings: prompt eval time = 914.34 ms / 18 tokens ( 50.80 ms per token, 19.69 tokens per second)\n",
"llama_print_timings: eval time = 7903.81 ms / 43 runs ( 183.81 ms per token, 5.44 tokens per second)\n",
"llama_print_timings: total time = 11065.60 ms / 61 tokens\n"
]
}
],
"source": [
"from langchain_core.pydantic_v1 import BaseModel\n",
"from langchain_core.utils.function_calling import convert_to_openai_tool\n",
"\n",
"\n",
"class AnswerWithJustification(BaseModel):\n",
" \"\"\"An answer to the user question along with justification for the answer.\"\"\"\n",
"\n",
" answer: str\n",
" justification: str\n",
"\n",
"\n",
"dict_schema = convert_to_openai_tool(AnswerWithJustification)\n",
"\n",
"structured_llm = llm.with_structured_output(dict_schema)\n",
"\n",
"result = structured_llm.invoke(\n",
" \"What weighs more a pound of bricks or a pound of feathers ?\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'answer': \"a pound is always the same weight, regardless of what it's made up off. So both options are equal in terms of their mass.\", 'justification': ''}\n"
]
}
],
"source": [
"print(result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Streaming\n"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Llama.generate: prefix-match hit\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"The\n",
" answer\n",
" to\n",
" the\n",
" multiplication\n",
" problem\n",
" \"\n",
"What\n",
"'s\n",
" \n",
"25\n",
" x\n",
" \n",
"5\n",
"?\"\n",
" would\n",
" be\n",
":\n",
"\n",
"\n",
"125\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"llama_print_timings: load time = 1077.71 ms\n",
"llama_print_timings: sample time = 10.60 ms / 20 runs ( 0.53 ms per token, 1886.26 tokens per second)\n",
"llama_print_timings: prompt eval time = 3661.75 ms / 12 tokens ( 305.15 ms per token, 3.28 tokens per second)\n",
"llama_print_timings: eval time = 2468.01 ms / 19 runs ( 129.90 ms per token, 7.70 tokens per second)\n",
"llama_print_timings: total time = 3133.11 ms / 31 tokens\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"for chunk in llm.stream(\"what is 25x5\"):\n",
" print(chunk.content, end=\"\\n\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatLlamaCpp features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.llamacpp.ChatLlamaCpp.html"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -112,6 +112,13 @@ CHAT_MODEL_FEAT_TABLE = {
"package": "langchain-community",
"link": "/docs/integrations/chat/edenai/",
},
"ChatLlamaCpp": {
"tool_calling": True,
"structured_output": True,
"local": True,
"package": "langchain-community",
"link": "/docs/integrations/chat/llamacpp",
},
}

@ -105,6 +105,7 @@ if TYPE_CHECKING:
from langchain_community.chat_models.llama_edge import (
LlamaEdgeChatService,
)
from langchain_community.chat_models.llamacpp import ChatLlamaCpp
from langchain_community.chat_models.maritalk import (
ChatMaritalk,
)
@ -200,6 +201,7 @@ __all__ = [
"ChatYandexGPT",
"ChatYuan2",
"ChatZhipuAI",
"ChatLlamaCpp",
"ErnieBotChat",
"FakeListChatModel",
"GPTRouter",
@ -265,6 +267,7 @@ _module_lookup = {
"QianfanChatEndpoint": "langchain_community.chat_models.baidu_qianfan_endpoint",
"VolcEngineMaasChat": "langchain_community.chat_models.volcengine_maas",
"ChatPremAI": "langchain_community.chat_models.premai",
"ChatLlamaCpp": "langchain_community.chat_models.llamacpp",
}

@ -0,0 +1,811 @@
import json
from operator import itemgetter
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
cast,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import InvalidToolCall, ToolCall, ToolCallChunk
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
class ChatLlamaCpp(BaseChatModel):
"""llama.cpp model.
To use, you should have the llama-cpp-python library installed, and provide the
path to the Llama model as a named parameter to the constructor.
Check out: https://github.com/abetlen/llama-cpp-python
"""
client: Any #: :meta private:
model_path: str
"""The path to the Llama model file."""
lora_base: Optional[str] = None
"""The path to the Llama LoRA base model."""
lora_path: Optional[str] = None
"""The path to the Llama LoRA. If None, no LoRa is loaded."""
n_ctx: int = 512
"""Token context window."""
n_parts: int = -1
"""Number of parts to split the model into.
If -1, the number of parts is automatically determined."""
seed: int = -1
"""Seed. If -1, a random seed is used."""
f16_kv: bool = True
"""Use half-precision for key/value cache."""
logits_all: bool = False
"""Return logits for all tokens, not just the last token."""
vocab_only: bool = False
"""Only load the vocabulary, no weights."""
use_mlock: bool = False
"""Force system to keep model in RAM."""
n_threads: Optional[int] = None
"""Number of threads to use.
If None, the number of threads is automatically determined."""
n_batch: int = 8
"""Number of tokens to process in parallel.
Should be a number between 1 and n_ctx."""
n_gpu_layers: Optional[int] = None
"""Number of layers to be loaded into gpu memory. Default None."""
suffix: Optional[str] = None
"""A suffix to append to the generated text. If None, no suffix is appended."""
max_tokens: int = 256
"""The maximum number of tokens to generate."""
temperature: float = 0.8
"""The temperature to use for sampling."""
top_p: float = 0.95
"""The top-p value to use for sampling."""
logprobs: Optional[int] = None
"""The number of logprobs to return. If None, no logprobs are returned."""
echo: bool = False
"""Whether to echo the prompt."""
stop: Optional[List[str]] = None
"""A list of strings to stop generation when encountered."""
repeat_penalty: float = 1.1
"""The penalty to apply to repeated tokens."""
top_k: int = 40
"""The top-k value to use for sampling."""
last_n_tokens_size: int = 64
"""The number of tokens to look back when applying the repeat_penalty."""
use_mmap: bool = True
"""Whether to keep the model loaded in RAM"""
rope_freq_scale: float = 1.0
"""Scale factor for rope sampling."""
rope_freq_base: float = 10000.0
"""Base frequency for rope sampling."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Any additional parameters to pass to llama_cpp.Llama."""
streaming: bool = True
"""Whether to stream the results, token by token."""
grammar_path: Optional[Union[str, Path]] = None
"""
grammar_path: Path to the .gbnf file that defines formal grammars
for constraining model outputs. For instance, the grammar can be used
to force the model to generate valid JSON or to speak exclusively in emojis. At most
one of grammar_path and grammar should be passed in.
"""
grammar: Any = None
"""
grammar: formal grammar for constraining model outputs. For instance, the grammar
can be used to force the model to generate valid JSON or to speak exclusively in
emojis. At most one of grammar_path and grammar should be passed in.
"""
verbose: bool = True
"""Print verbose output to stderr."""
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
try:
from llama_cpp import Llama, LlamaGrammar
except ImportError:
raise ImportError(
"Could not import llama-cpp-python library. "
"Please install the llama-cpp-python library to "
"use this embedding model: pip install llama-cpp-python"
)
model_path = values["model_path"]
model_param_names = [
"rope_freq_scale",
"rope_freq_base",
"lora_path",
"lora_base",
"n_ctx",
"n_parts",
"seed",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"n_threads",
"n_batch",
"use_mmap",
"last_n_tokens_size",
"verbose",
]
model_params = {k: values[k] for k in model_param_names}
# For backwards compatibility, only include if non-null.
if values["n_gpu_layers"] is not None:
model_params["n_gpu_layers"] = values["n_gpu_layers"]
model_params.update(values["model_kwargs"])
try:
values["client"] = Llama(model_path, **model_params)
except Exception as e:
raise ValueError(
f"Could not load Llama model from path: {model_path}. "
f"Received error {e}"
)
if values["grammar"] and values["grammar_path"]:
grammar = values["grammar"]
grammar_path = values["grammar_path"]
raise ValueError(
"Can only pass in one of grammar and grammar_path. Received "
f"{grammar=} and {grammar_path=}."
)
elif isinstance(values["grammar"], str):
values["grammar"] = LlamaGrammar.from_string(values["grammar"])
elif values["grammar_path"]:
values["grammar"] = LlamaGrammar.from_file(values["grammar_path"])
else:
pass
return values
def _get_parameters(self, stop: Optional[List[str]]) -> Dict[str, Any]:
"""
Performs sanity check, preparing parameters in format needed by llama_cpp.
Returns:
Dictionary containing the combined parameters.
"""
params = self._default_params
# llama_cpp expects the "stop" key not this, so we remove it:
stop_sequences = params.pop("stop_sequences")
# then sets it as configured, or default to an empty list:
params["stop"] = stop or stop_sequences or self.stop or []
return params
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts
def _create_chat_result(self, response: dict) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
generation_info = dict(finish_reason=res.get("finish_reason"))
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(message=message, generation_info=generation_info)
generations.append(gen)
token_usage = response.get("usage", {})
llm_output = {
"token_usage": token_usage,
# "system_fingerprint": response.get("system_fingerprint", ""),
}
return ChatResult(generations=generations, llm_output=llm_output)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = {**self._get_parameters(stop), **kwargs}
# Check tool_choice is whether available, if yes then run no stream with tool
# calling
if self.streaming and not params.get("tool_choice"):
stream_iter = self._stream(messages, run_manager=run_manager, **kwargs)
return generate_from_stream(stream_iter)
message_dicts = self._create_message_dicts(messages)
response = self.client.create_chat_completion(messages=message_dicts, **params)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = {**self._get_parameters(stop), **kwargs}
message_dicts = self._create_message_dicts(messages)
result = self.client.create_chat_completion(
messages=message_dicts, stream=True, **params
)
default_chunk_class = AIMessageChunk
count = 0
for chunk in result:
count += 1
if not isinstance(chunk, dict):
chunk = chunk.model_dump()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
if choice["delta"] is None:
continue
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
yield chunk
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[Union[Dict[str, Dict], bool, str]] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model
tool_choice: does not currently support "any", "auto" choices like OpenAI
tool-calling API. should be a dict of the form to force this tool
{"type": "function", "function": {"name": <<tool_name>>}}.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
tool_names = [ft["function"]["name"] for ft in formatted_tools]
if tool_choice:
if isinstance(tool_choice, dict):
if not any(
tool_choice["function"]["name"] == name for name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice=} was specified, but the only "
f"provided tools were {tool_names}."
)
elif isinstance(tool_choice, str):
chosen = [
f for f in formatted_tools if f["function"]["name"] == tool_choice
]
if not chosen:
raise ValueError(
f"Tool choice {tool_choice=} was specified, but the only "
f"provided tools were {tool_names}."
)
elif isinstance(tool_choice, bool):
if len(formatted_tools) > 1:
raise ValueError(
"tool_choice=True can only be specified when a single tool is "
f"passed in. Received {len(tools)} tools."
)
tool_choice = formatted_tools[0]
else:
raise ValueError(
"""Unrecognized tool_choice type. Expected dict having format like
this {"type": "function", "function": {"name": <<tool_name>>}}"""
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OpenAI function-calling spec or be a valid JSON schema
with top level 'title' and 'description' keys specified.
include_raw: If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
kwargs: Any other args to bind to model, ``self.bind(..., **kwargs)``.
Returns:
A Runnable that takes any ChatModel input and returns as output:
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
Example: Pydantic schema (include_raw=False):
.. code-block:: python
from langchain_community.chat_models import ChatLlamaCpp
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatLlamaCpp(
temperature=0.,
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
n_ctx=10000,
n_gpu_layers=4,
n_batch=200,
max_tokens=512,
n_threads=multiprocessing.cpu_count() - 1,
repeat_penalty=1.5,
top_p=0.5,
stop=["<|end_of_text|>", "<|eot_id|>"],
)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'
# )
Example: Pydantic schema (include_raw=True):
.. code-block:: python
from langchain_community.chat_models import ChatLlamaCpp
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatLlamaCpp(
temperature=0.,
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
n_ctx=10000,
n_gpu_layers=4,
n_batch=200,
max_tokens=512,
n_threads=multiprocessing.cpu_count() - 1,
repeat_penalty=1.5,
top_p=0.5,
stop=["<|end_of_text|>", "<|eot_id|>"],
)
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
# 'parsing_error': None
# }
Example: dict schema (include_raw=False):
.. code-block:: python
from langchain_community.chat_models import ChatLlamaCpp
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatLlamaCpp(
temperature=0.,
model_path="./SanctumAI-meta-llama-3-8b-instruct.Q8_0.gguf",
n_ctx=10000,
n_gpu_layers=4,
n_batch=200,
max_tokens=512,
n_threads=multiprocessing.cpu_count() - 1,
repeat_penalty=1.5,
top_p=0.5,
stop=["<|end_of_text|>", "<|eot_id|>"],
)
structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
# }
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = isinstance(schema, type) and issubclass(schema, BaseModel)
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], tool_choice=True)
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[cast(Type, schema)], first_tool_only=True
)
else:
key_name = convert_to_openai_tool(schema)["function"]["name"]
output_parser = JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
**{"model_path": self.model_path},
**self._default_params,
}
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "llama-cpp-python"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling create_chat_completion."""
params: Dict = {
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"logprobs": self.logprobs,
"stop_sequences": self.stop, # key here is convention among LLM classes
"repeat_penalty": self.repeat_penalty,
}
if self.grammar:
params["grammar"] = self.grammar
return params
def _lc_tool_call_to_openai_tool_call(tool_call: ToolCall) -> dict:
return {
"type": "function",
"id": tool_call["id"],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["args"]),
},
}
def _lc_invalid_tool_call_to_openai_tool_call(
invalid_tool_call: InvalidToolCall,
) -> dict:
return {
"type": "function",
"id": invalid_tool_call["id"],
"function": {
"name": invalid_tool_call["name"],
"arguments": invalid_tool_call["args"],
},
}
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
"""Convert a dictionary to a LangChain message.
Args:
_dict: The dictionary.
Returns:
The LangChain message.
"""
role = _dict.get("role")
name = _dict.get("name")
id_ = _dict.get("id")
if role == "user":
return HumanMessage(content=_dict.get("content", ""), id=id_, name=name)
elif role == "assistant":
# Fix for azure
# Also OpenAI returns None for tool invocations
content = _dict.get("content", "") or ""
additional_kwargs: Dict = {}
if function_call := _dict.get("function_call"):
additional_kwargs["function_call"] = dict(function_call)
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tc = parse_tool_call(raw_tool_call, return_id=True)
except Exception as e:
invalid_tc = make_invalid_tool_call(raw_tool_call, str(e))
invalid_tool_calls.append(invalid_tc)
else:
if not tc:
continue
else:
tool_calls.append(tc)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
name=name,
id=id_,
tool_calls=tool_calls, # type: ignore[arg-type]
invalid_tool_calls=invalid_tool_calls,
)
elif role == "system":
return SystemMessage(content=_dict.get("content", ""), name=name, id=id_)
elif role == "function":
return FunctionMessage(
content=_dict.get("content", ""), name=cast(str, _dict.get("name")), id=id_
)
elif role == "tool":
additional_kwargs = {}
if "name" in _dict:
additional_kwargs["name"] = _dict["name"]
return ToolMessage(
content=_dict.get("content", ""),
tool_call_id=cast(str, _dict.get("tool_call_id")),
additional_kwargs=additional_kwargs,
name=name,
id=id_,
)
else:
return ChatMessage(
content=_dict.get("content", ""), role=cast(str, role), id=id_
)
def _format_message_content(content: Any) -> Any:
"""Format message content."""
if content and isinstance(content, list):
# Remove unexpected block types
formatted_content = []
for block in content:
if (
isinstance(block, dict)
and "type" in block
and block["type"] == "tool_use"
):
continue
else:
formatted_content.append(block)
else:
formatted_content = content
return formatted_content
def _convert_message_to_dict(message: BaseMessage) -> dict:
"""Convert a LangChain message to a dictionary.
Args:
message: The LangChain message.
Returns:
The dictionary.
"""
message_dict: Dict[str, Any] = {
"content": _format_message_content(message.content),
}
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name
# populate role and additional message data
if isinstance(message, ChatMessage):
message_dict["role"] = message.role
elif isinstance(message, HumanMessage):
message_dict["role"] = "user"
elif isinstance(message, AIMessage):
message_dict["role"] = "assistant"
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
if message.tool_calls or message.invalid_tool_calls:
message_dict["tool_calls"] = [
_lc_tool_call_to_openai_tool_call(tc) for tc in message.tool_calls
] + [
_lc_invalid_tool_call_to_openai_tool_call(tc)
for tc in message.invalid_tool_calls
]
elif "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
tool_call_supported_props = {"id", "type", "function"}
message_dict["tool_calls"] = [
{k: v for k, v in tool_call.items() if k in tool_call_supported_props}
for tool_call in message_dict["tool_calls"]
]
else:
pass
# If tool calls present, content null value should be None not empty string.
if "function_call" in message_dict or "tool_calls" in message_dict:
message_dict["content"] = message_dict["content"] or None
elif isinstance(message, SystemMessage):
message_dict["role"] = "system"
elif isinstance(message, FunctionMessage):
message_dict["role"] = "function"
elif isinstance(message, ToolMessage):
message_dict["role"] = "tool"
message_dict["tool_call_id"] = message.tool_call_id
supported_props = {"content", "role", "tool_call_id"}
message_dict = {k: v for k, v in message_dict.items() if k in supported_props}
else:
raise TypeError(f"Got unknown type {message}")
return message_dict
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
id_ = _dict.get("id")
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for rtc in raw_tool_calls:
try:
tool_call = ToolCallChunk(
name=rtc["function"].get("name"),
args=rtc["function"].get("arguments"),
id=rtc.get("id"),
index=rtc["index"],
)
tool_call_chunks.append(tool_call)
except KeyError:
pass
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content, id=id_)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=id_,
tool_call_chunks=tool_call_chunks,
)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content, id=id_)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=id_
)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role, id=id_)
else:
return default_class(content=content, id=id_) # type: ignore

@ -21,6 +21,7 @@ EXPECTED_ALL = [
"ChatKonko",
"ChatLiteLLM",
"ChatLiteLLMRouter",
"ChatLlamaCpp",
"ChatMLflowAIGateway",
"ChatMaritalk",
"ChatMlflow",

Loading…
Cancel
Save