mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
partners: add xAI chat integration (#28032)
This commit is contained in:
parent
2898b95ca7
commit
48ee322a78
1
.github/workflows/_integration_test.yml
vendored
1
.github/workflows/_integration_test.yml
vendored
@ -79,6 +79,7 @@ jobs:
|
||||
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
|
||||
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
|
||||
UPSTAGE_API_KEY: ${{ secrets.UPSTAGE_API_KEY }}
|
||||
XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
|
||||
run: |
|
||||
make integration_tests
|
||||
|
||||
|
1
.github/workflows/_release.yml
vendored
1
.github/workflows/_release.yml
vendored
@ -308,6 +308,7 @@ jobs:
|
||||
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
|
||||
UPSTAGE_API_KEY: ${{ secrets.UPSTAGE_API_KEY }}
|
||||
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
|
||||
XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
|
||||
run: make integration_tests
|
||||
working-directory: ${{ inputs.working-directory }}
|
||||
|
||||
|
332
docs/docs/integrations/chat/xai.ipynb
Normal file
332
docs/docs/integrations/chat/xai.ipynb
Normal file
@ -0,0 +1,332 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "raw",
|
||||
"id": "afaf8039",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"sidebar_label: xAI\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e49f1e0d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ChatXAI\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"This page will help you get started with xAI [chat models](../../concepts/chat_models.mdx). For detailed documentation of all `ChatXAI` features and configurations head to the [API reference](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html).\n",
|
||||
"\n",
|
||||
"[xAI](https://console.x.ai/) offers an API to interact with Grok models.\n",
|
||||
"\n",
|
||||
"## Overview\n",
|
||||
"### Integration details\n",
|
||||
"\n",
|
||||
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/xai) | Package downloads | Package latest |\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| [ChatXAI](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html) | [langchain-xai](https://python.langchain.com/api_reference/xai/index.html) | ❌ | beta | ✅ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-xai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-xai?style=flat-square&label=%20) |\n",
|
||||
"\n",
|
||||
"### Model features\n",
|
||||
"| [Tool calling](../../how_to/tool_calling.ipynb) | [Structured output](../../how_to/structured_output.ipynb) | JSON mode | [Image input](../../how_to/multimodal_inputs.ipynb) | Audio input | Video input | [Token-level streaming](../../how_to/chat_streaming.ipynb) | Native async | [Token usage](../../how_to/chat_token_usage_tracking.ipynb) | [Logprobs](../../how_to/logprobs.ipynb) |\n",
|
||||
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
|
||||
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ✅ | \n",
|
||||
"\n",
|
||||
"## Setup\n",
|
||||
"\n",
|
||||
"To access xAI models you'll need to create an xAI account, get an API key, and install the `langchain-xai` integration package.\n",
|
||||
"\n",
|
||||
"### Credentials\n",
|
||||
"\n",
|
||||
"Head to [this page](https://console.x.ai/) to sign up for xAI and generate an API key. Once you've done this set the `XAI_API_KEY` environment variable:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"if \"XAI_API_KEY\" not in os.environ:\n",
|
||||
" os.environ[\"XAI_API_KEY\"] = getpass.getpass(\"Enter your xAI API key: \")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "72ee0c4b-9764-423a-9dbf-95129e185210",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
|
||||
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0730d6a1-c893-4840-9817-5e5251676d5d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Installation\n",
|
||||
"\n",
|
||||
"The LangChain xAI integration lives in the `langchain-xai` package:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU langchain-xai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a38cde65-254d-4219-a441-068766c0d4b5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Instantiation\n",
|
||||
"\n",
|
||||
"Now we can instantiate our model object and generate chat completions:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_xai import ChatXAI\n",
|
||||
"\n",
|
||||
"llm = ChatXAI(\n",
|
||||
" model=\"grok-beta\",\n",
|
||||
" temperature=0,\n",
|
||||
" max_tokens=None,\n",
|
||||
" timeout=None,\n",
|
||||
" max_retries=2,\n",
|
||||
" # other params...\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2b4f3e15",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invocation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "62e0dbc3",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content=\"J'adore programmer.\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 30, 'total_tokens': 36, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'grok-beta', 'system_fingerprint': 'fp_14b89b2dfc', 'finish_reason': 'stop', 'logprobs': None}, id='run-adffb7a3-e48a-4f52-b694-340d85abe5c3-0', usage_metadata={'input_tokens': 30, 'output_tokens': 6, 'total_tokens': 36, 'input_token_details': {}, 'output_token_details': {}})"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
|
||||
" ),\n",
|
||||
" (\"human\", \"I love programming.\"),\n",
|
||||
"]\n",
|
||||
"ai_msg = llm.invoke(messages)\n",
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"J'adore programmer.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(ai_msg.content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Chaining\n",
|
||||
"\n",
|
||||
"We can [chain](../../how_to/sequence.ipynb) our model with a prompt template like so:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='Ich liebe das Programmieren.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 7, 'prompt_tokens': 25, 'total_tokens': 32, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'grok-beta', 'system_fingerprint': 'fp_14b89b2dfc', 'finish_reason': 'stop', 'logprobs': None}, id='run-569fc8dc-101b-4e6d-864e-d4fa80df2b63-0', usage_metadata={'input_tokens': 25, 'output_tokens': 7, 'total_tokens': 32, 'input_token_details': {}, 'output_token_details': {}})"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_core.prompts import ChatPromptTemplate\n",
|
||||
"\n",
|
||||
"prompt = ChatPromptTemplate.from_messages(\n",
|
||||
" [\n",
|
||||
" (\n",
|
||||
" \"system\",\n",
|
||||
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
|
||||
" ),\n",
|
||||
" (\"human\", \"{input}\"),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"chain = prompt | llm\n",
|
||||
"chain.invoke(\n",
|
||||
" {\n",
|
||||
" \"input_language\": \"English\",\n",
|
||||
" \"output_language\": \"German\",\n",
|
||||
" \"input\": \"I love programming.\",\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e074bce1-0994-4b83-b393-ae7aa7e21750",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Tool calling\n",
|
||||
"\n",
|
||||
"ChatXAI has a [tool calling](https://docs.x.ai/docs#capabilities) (we use \"tool calling\" and \"function calling\" interchangeably here) API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool. Tool-calling is extremely useful for building tool-using chains and agents, and for getting structured outputs from models more generally.\n",
|
||||
"\n",
|
||||
"### ChatXAI.bind_tools()\n",
|
||||
"\n",
|
||||
"With `ChatXAI.bind_tools`, we can easily pass in Pydantic classes, dict schemas, LangChain tools, or even functions as tools to the model. Under the hood these are converted to an OpenAI tool schemas, which looks like:\n",
|
||||
"```\n",
|
||||
"{\n",
|
||||
" \"name\": \"...\",\n",
|
||||
" \"description\": \"...\",\n",
|
||||
" \"parameters\": {...} # JSONSchema\n",
|
||||
"}\n",
|
||||
"```\n",
|
||||
"and passed in every model invocation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "c6bfe929-ec02-46bd-9d54-76350edddabc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GetWeather(BaseModel):\n",
|
||||
" \"\"\"Get the current weather in a given location\"\"\"\n",
|
||||
"\n",
|
||||
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"llm_with_tools = llm.bind_tools([GetWeather])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "5265c892-d8c2-48af-aef5-adbee1647ba6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"AIMessage(content='I am retrieving the current weather for San Francisco.', additional_kwargs={'tool_calls': [{'id': '0', 'function': {'arguments': '{\"location\":\"San Francisco, CA\"}', 'name': 'GetWeather'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 11, 'prompt_tokens': 151, 'total_tokens': 162, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'grok-beta', 'system_fingerprint': 'fp_14b89b2dfc', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-73707da7-afec-4a52-bee1-a176b0ab8585-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': '0', 'type': 'tool_call'}], usage_metadata={'input_tokens': 151, 'output_tokens': 11, 'total_tokens': 162, 'input_token_details': {}, 'output_token_details': {}})"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ai_msg = llm_with_tools.invoke(\n",
|
||||
" \"what is the weather like in San Francisco\",\n",
|
||||
")\n",
|
||||
"ai_msg"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## API reference\n",
|
||||
"\n",
|
||||
"For detailed documentation of all `ChatXAI` features and configurations head to the API reference: https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
97
docs/docs/integrations/providers/xai.ipynb
Normal file
97
docs/docs/integrations/providers/xai.ipynb
Normal file
@ -0,0 +1,97 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2970dd75-8ebf-4b51-8282-9b454b8f356d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# xAI\n",
|
||||
"\n",
|
||||
"[xAI](https://console.x.ai) offers an API to interact with Grok models.\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with xAI models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1c47fc36",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Installation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1ecdb29d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install --upgrade langchain-xai"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "89883202",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Environment\n",
|
||||
"\n",
|
||||
"To use xAI, you'll need to [create an API key](https://console.x.ai/). The API key can be passed in as an init param ``xai_api_key`` or set as environment variable ``XAI_API_KEY``."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8304b4d9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "637bb53f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Querying chat models with xAI\n",
|
||||
"\n",
|
||||
"from langchain_xai import ChatXAI\n",
|
||||
"\n",
|
||||
"chat = ChatXAI(\n",
|
||||
" # xai_api_key=\"YOUR_API_KEY\",\n",
|
||||
" model=\"grok-beta\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# stream the response back from the model\n",
|
||||
"for m in chat.stream(\"Tell me fun things to do in NYC\"):\n",
|
||||
" print(m.content, end=\"\", flush=True)\n",
|
||||
"\n",
|
||||
"# if you don't want to do streaming, you can use the invoke method\n",
|
||||
"# chat.invoke(\"Tell me fun things to do in NYC\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -227,6 +227,17 @@ const FEATURE_TABLES = {
|
||||
"local": false,
|
||||
"apiLink": "https://python.langchain.com/api_reference/ibm/chat_models/langchain_ibm.chat_models.ChatWatsonx.html"
|
||||
},
|
||||
{
|
||||
"name": "ChatXAI",
|
||||
"package": "langchain-xai",
|
||||
"link": "xai",
|
||||
"structured_output": true,
|
||||
"tool_calling": true,
|
||||
"json_mode": false,
|
||||
"multimodal": false,
|
||||
"local": false,
|
||||
"apiLink": "https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
|
||||
},
|
||||
],
|
||||
},
|
||||
llms: {
|
||||
|
@ -24,6 +24,7 @@ DEFAULT_NAMESPACES = [
|
||||
"langchain_google_vertexai",
|
||||
"langchain_mistralai",
|
||||
"langchain_fireworks",
|
||||
"langchain_xai",
|
||||
]
|
||||
# Namespaces for which only deserializing via the SERIALIZABLE_MAPPING is allowed.
|
||||
# Load by path is not allowed.
|
||||
|
21
libs/partners/xai/LICENSE
Normal file
21
libs/partners/xai/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
56
libs/partners/xai/Makefile
Normal file
56
libs/partners/xai/Makefile
Normal file
@ -0,0 +1,56 @@
|
||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
integration_test integration_tests: TEST_FILE=tests/integration_tests/
|
||||
|
||||
test tests integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/xai --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_xai
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
|
||||
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
|
||||
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_xai -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
17
libs/partners/xai/README.md
Normal file
17
libs/partners/xai/README.md
Normal file
@ -0,0 +1,17 @@
|
||||
# langchain-xai
|
||||
|
||||
This package contains the LangChain integrations for [xAI](https://x.ai/) through their [APIs](https://console.x.ai).
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Install the LangChain partner package
|
||||
|
||||
```bash
|
||||
pip install -U langchain-xai
|
||||
```
|
||||
|
||||
- Get your xAI api key from the [xAI Dashboard](https://console.x.ai) and set it as an environment variable (`XAI_API_KEY`)
|
||||
|
||||
## Chat Completions
|
||||
|
||||
This package contains the `ChatXAI` class, which is the recommended way to interface with xAI chat models.
|
5
libs/partners/xai/langchain_xai/__init__.py
Normal file
5
libs/partners/xai/langchain_xai/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""This package provides the xAI integration for LangChain."""
|
||||
|
||||
from langchain_xai.chat_models import ChatXAI
|
||||
|
||||
__all__ = ["ChatXAI"]
|
355
libs/partners/xai/langchain_xai/chat_models.py
Normal file
355
libs/partners/xai/langchain_xai/chat_models.py
Normal file
@ -0,0 +1,355 @@
|
||||
"""Wrapper around xAI's Chat Completions API."""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.utils import secret_from_env
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
from pydantic import ConfigDict, Field, SecretStr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ChatXAI(BaseChatOpenAI): # type: ignore[override]
|
||||
r"""ChatXAI chat model.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-xai`` and set environment variable ``XAI_API_KEY``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langchain-xai
|
||||
export XAI_API_KEY="your-api-key"
|
||||
|
||||
|
||||
Key init args — completion params:
|
||||
model: str
|
||||
Name of model to use.
|
||||
temperature: float
|
||||
Sampling temperature.
|
||||
max_tokens: Optional[int]
|
||||
Max number of tokens to generate.
|
||||
logprobs: Optional[bool]
|
||||
Whether to return logprobs.
|
||||
|
||||
Key init args — client params:
|
||||
timeout: Union[float, Tuple[float, float], Any, None]
|
||||
Timeout for requests.
|
||||
max_retries: int
|
||||
Max number of retries.
|
||||
api_key: Optional[str]
|
||||
xAI API key. If not passed in will be read from env var `XAI_API_KEY`.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
llm = ChatXAI(
|
||||
model="grok-beta",
|
||||
temperature=0,
|
||||
max_tokens=None,
|
||||
timeout=None,
|
||||
max_retries=2,
|
||||
# api_key="...",
|
||||
# other params...
|
||||
)
|
||||
|
||||
Invoke:
|
||||
.. code-block:: python
|
||||
|
||||
messages = [
|
||||
(
|
||||
"system",
|
||||
"You are a helpful translator. Translate the user sentence to French.",
|
||||
),
|
||||
("human", "I love programming."),
|
||||
]
|
||||
llm.invoke(messages)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessage(
|
||||
content="J'adore la programmation.",
|
||||
response_metadata={
|
||||
'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
|
||||
'model_name': 'grok-beta',
|
||||
'system_fingerprint': None,
|
||||
'finish_reason': 'stop',
|
||||
'logprobs': None
|
||||
},
|
||||
id='run-168dceca-3b8b-4283-94e3-4c739dbc1525-0',
|
||||
usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41})
|
||||
|
||||
Stream:
|
||||
.. code-block:: python
|
||||
|
||||
for chunk in llm.stream(messages):
|
||||
print(chunk)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
content='J' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content="'" id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content='ad' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content='ore' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content=' la' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content=' programm' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content='ation' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content='.' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-beta'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
|
||||
|
||||
|
||||
Async:
|
||||
.. code-block:: python
|
||||
|
||||
await llm.ainvoke(messages)
|
||||
|
||||
# stream:
|
||||
# async for chunk in (await llm.astream(messages))
|
||||
|
||||
# batch:
|
||||
# await llm.abatch([messages])
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessage(
|
||||
content="J'adore la programmation.",
|
||||
response_metadata={
|
||||
'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
|
||||
'model_name': 'grok-beta',
|
||||
'system_fingerprint': None,
|
||||
'finish_reason': 'stop',
|
||||
'logprobs': None
|
||||
},
|
||||
id='run-09371a11-7f72-4c53-8e7c-9de5c238b34c-0',
|
||||
usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41})
|
||||
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
llm = ChatXAI(model="grok-beta")
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
'''Get the current weather in a given location'''
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
class GetPopulation(BaseModel):
|
||||
'''Get the current population in a given location'''
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
||||
ai_msg = llm_with_tools.invoke(
|
||||
"Which city is bigger: LA or NY?"
|
||||
)
|
||||
ai_msg.tool_calls
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[
|
||||
{
|
||||
'name': 'GetPopulation',
|
||||
'args': {'location': 'NY'},
|
||||
'id': 'call_m5tstyn2004pre9bfuxvom8x',
|
||||
'type': 'tool_call'
|
||||
},
|
||||
{
|
||||
'name': 'GetPopulation',
|
||||
'args': {'location': 'LA'},
|
||||
'id': 'call_0vjgq455gq1av5sp9eb1pw6a',
|
||||
'type': 'tool_call'
|
||||
}
|
||||
]
|
||||
|
||||
Structured output:
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Joke(BaseModel):
|
||||
'''Joke to tell user.'''
|
||||
|
||||
setup: str = Field(description="The setup of the joke")
|
||||
punchline: str = Field(description="The punchline to the joke")
|
||||
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
|
||||
|
||||
|
||||
structured_llm = llm.with_structured_output(Joke)
|
||||
structured_llm.invoke("Tell me a joke about cats")
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Joke(
|
||||
setup='Why was the cat sitting on the computer?',
|
||||
punchline='To keep an eye on the mouse!',
|
||||
rating=7
|
||||
)
|
||||
|
||||
Token usage:
|
||||
.. code-block:: python
|
||||
|
||||
ai_msg = llm.invoke(messages)
|
||||
ai_msg.usage_metadata
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{'input_tokens': 37, 'output_tokens': 6, 'total_tokens': 43}
|
||||
|
||||
Logprobs:
|
||||
.. code-block:: python
|
||||
|
||||
logprobs_llm = llm.bind(logprobs=True)
|
||||
messages=[("human","Say Hello World! Do not return anything else.")]
|
||||
ai_msg = logprobs_llm.invoke(messages)
|
||||
ai_msg.response_metadata["logprobs"]
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'content': None,
|
||||
'token_ids': [22557, 3304, 28808, 2],
|
||||
'tokens': [' Hello', ' World', '!', '</s>'],
|
||||
'token_logprobs': [-4.7683716e-06, -5.9604645e-07, 0, -0.057373047]
|
||||
}
|
||||
|
||||
|
||||
Response metadata
|
||||
.. code-block:: python
|
||||
|
||||
ai_msg = llm.invoke(messages)
|
||||
ai_msg.response_metadata
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'token_usage': {
|
||||
'completion_tokens': 4,
|
||||
'prompt_tokens': 19,
|
||||
'total_tokens': 23
|
||||
},
|
||||
'model_name': 'grok-beta',
|
||||
'system_fingerprint': None,
|
||||
'finish_reason': 'stop',
|
||||
'logprobs': None
|
||||
}
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
model_name: str = Field(alias="model")
|
||||
"""Model name to use."""
|
||||
xai_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("XAI_API_KEY", default=None),
|
||||
)
|
||||
"""xAI API key.
|
||||
|
||||
Automatically read from env variable `XAI_API_KEY` if not provided.
|
||||
"""
|
||||
xai_api_base: str = Field(default="https://api.x.ai/v1/")
|
||||
"""Base URL path for API requests."""
|
||||
|
||||
openai_api_key: Optional[SecretStr] = None
|
||||
openai_api_base: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
"""A map of constructor argument names to secret ids.
|
||||
|
||||
For example,
|
||||
{"xai_api_key": "XAI_API_KEY"}
|
||||
"""
|
||||
return {"xai_api_key": "XAI_API_KEY"}
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain_xai", "chat_models"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, Any]:
|
||||
"""List of attribute names that should be included in the serialized kwargs.
|
||||
|
||||
These attributes must be accepted by the constructor.
|
||||
"""
|
||||
attributes: Dict[str, Any] = {}
|
||||
|
||||
if self.xai_api_base:
|
||||
attributes["xai_api_base"] = self.xai_api_base
|
||||
|
||||
return attributes
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this model can be serialized by Langchain."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "xai-chat"
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[List[str]] = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
params = super()._get_ls_params(stop=stop, **kwargs)
|
||||
params["ls_provider"] = "xai"
|
||||
return params
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if self.n < 1:
|
||||
raise ValueError("n must be at least 1.")
|
||||
if self.n > 1 and self.streaming:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
client_params: dict = {
|
||||
"api_key": (
|
||||
self.xai_api_key.get_secret_value() if self.xai_api_key else None
|
||||
),
|
||||
"base_url": self.xai_api_base,
|
||||
"timeout": self.request_timeout,
|
||||
"max_retries": self.max_retries,
|
||||
"default_headers": self.default_headers,
|
||||
"default_query": self.default_query,
|
||||
}
|
||||
|
||||
if client_params["api_key"] is None:
|
||||
raise ValueError(
|
||||
"xAI API key is not set. Please set it in the `xai_api_key` field or "
|
||||
"in the `XAI_API_KEY` environment variable."
|
||||
)
|
||||
|
||||
if not (self.client or None):
|
||||
sync_specific: dict = {"http_client": self.http_client}
|
||||
self.client = openai.OpenAI(
|
||||
**client_params, **sync_specific
|
||||
).chat.completions
|
||||
if not (self.async_client or None):
|
||||
async_specific: dict = {"http_client": self.http_async_client}
|
||||
self.async_client = openai.AsyncOpenAI(
|
||||
**client_params, **async_specific
|
||||
).chat.completions
|
||||
return self
|
0
libs/partners/xai/langchain_xai/py.typed
Normal file
0
libs/partners/xai/langchain_xai/py.typed
Normal file
2074
libs/partners/xai/poetry.lock
generated
Normal file
2074
libs/partners/xai/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
100
libs/partners/xai/pyproject.toml
Normal file
100
libs/partners/xai/pyproject.toml
Normal file
@ -0,0 +1,100 @@
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "langchain-xai"
|
||||
version = "0.1.0"
|
||||
description = "An integration package connecting xAI and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/xai"
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-xai%3D%3D0%22&expanded=true"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9,<4.0"
|
||||
langchain-openai = "^0.2"
|
||||
langchain-core = "^0.3"
|
||||
requests = "^2"
|
||||
aiohttp = "^3.9.1"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "D"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.typing]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**" = ["D"]
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
docarray = "^0.32.1"
|
||||
langchain-openai = { path = "../openai", develop = true }
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
langchain-standard-tests = { path = "../../standard-tests", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
[[tool.poetry.group.test_integration.dependencies.numpy]]
|
||||
version = "^1"
|
||||
python = "<3.12"
|
||||
|
||||
[[tool.poetry.group.test_integration.dependencies.numpy]]
|
||||
version = "^1.26.0"
|
||||
python = ">=3.12"
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^1.10"
|
||||
types-requests = "^2"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
19
libs/partners/xai/scripts/check_imports.py
Normal file
19
libs/partners/xai/scripts/check_imports.py
Normal file
@ -0,0 +1,19 @@
|
||||
"""This module checks if the given python files can be imported without error."""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_failure = True
|
||||
print(file)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
17
libs/partners/xai/scripts/lint_imports.sh
Executable file
17
libs/partners/xai/scripts/lint_imports.sh
Executable file
@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
0
libs/partners/xai/tests/__init__.py
Normal file
0
libs/partners/xai/tests/__init__.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.rate_limiters import InMemoryRateLimiter
|
||||
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
|
||||
ChatModelIntegrationTests, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
# Initialize the rate limiter in global scope, so it can be re-used
|
||||
# across tests.
|
||||
rate_limiter = InMemoryRateLimiter(
|
||||
requests_per_second=0.5,
|
||||
)
|
||||
|
||||
|
||||
class TestXAIStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatXAI
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {
|
||||
"model": "grok-beta",
|
||||
"rate_limiter": rate_limiter,
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_choice_value(self) -> Optional[str]:
|
||||
"""Value to use for tool choice when used in tests."""
|
||||
return "tool_name"
|
||||
|
||||
@pytest.mark.xfail(reason="Not yet supported.")
|
||||
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
|
||||
super().test_usage_metadata_streaming(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Can't handle AIMessage with empty content.")
|
||||
def test_tool_message_error_status(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_message_error_status(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Can't handle AIMessage with empty content.")
|
||||
def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
|
||||
super().test_structured_few_shot_examples(model)
|
||||
|
||||
@pytest.mark.xfail(reason="Can't handle AIMessage with empty content.")
|
||||
def test_tool_message_histories_string_content(self, model: BaseChatModel) -> None:
|
||||
super().test_tool_message_histories_string_content(model)
|
@ -0,0 +1,7 @@
|
||||
import pytest # type: ignore[import-not-found]
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
3
libs/partners/xai/tests/unit_tests/__init__.py
Normal file
3
libs/partners/xai/tests/unit_tests/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
import os
|
||||
|
||||
os.environ["XAI_API_KEY"] = "test"
|
@ -0,0 +1,31 @@
|
||||
# serializer version: 1
|
||||
# name: TestXAIStandard.test_serdes[serialized]
|
||||
dict({
|
||||
'id': list([
|
||||
'langchain_xai',
|
||||
'chat_models',
|
||||
'ChatXAI',
|
||||
]),
|
||||
'kwargs': dict({
|
||||
'max_retries': 2,
|
||||
'max_tokens': 100,
|
||||
'model_name': 'grok-beta',
|
||||
'n': 1,
|
||||
'request_timeout': 60.0,
|
||||
'stop': list([
|
||||
]),
|
||||
'temperature': 0.0,
|
||||
'xai_api_base': 'https://api.x.ai/v1/',
|
||||
'xai_api_key': dict({
|
||||
'id': list([
|
||||
'XAI_API_KEY',
|
||||
]),
|
||||
'lc': 1,
|
||||
'type': 'secret',
|
||||
}),
|
||||
}),
|
||||
'lc': 1,
|
||||
'name': 'ChatXAI',
|
||||
'type': 'constructor',
|
||||
})
|
||||
# ---
|
129
libs/partners/xai/tests/unit_tests/test_chat_models.py
Normal file
129
libs/partners/xai/tests/unit_tests/test_chat_models.py
Normal file
@ -0,0 +1,129 @@
|
||||
import json
|
||||
|
||||
import pytest # type: ignore[import-not-found]
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_openai.chat_models.base import (
|
||||
_convert_dict_to_message,
|
||||
_convert_message_to_dict,
|
||||
)
|
||||
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test chat model initialization."""
|
||||
ChatXAI(model="grok-beta")
|
||||
|
||||
|
||||
def test_xai_model_param() -> None:
|
||||
llm = ChatXAI(model="foo")
|
||||
assert llm.model_name == "foo"
|
||||
llm = ChatXAI(model_name="foo") # type: ignore[call-arg]
|
||||
assert llm.model_name == "foo"
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params["ls_provider"] == "xai"
|
||||
|
||||
|
||||
def test_chat_xai_invalid_streaming_params() -> None:
|
||||
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||
with pytest.raises(ValueError):
|
||||
ChatXAI(
|
||||
model="grok-beta",
|
||||
max_tokens=10,
|
||||
streaming=True,
|
||||
temperature=0,
|
||||
n=5,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_xai_extra_kwargs() -> None:
|
||||
"""Test extra kwargs to chat xai."""
|
||||
# Check that foo is saved in extra_kwargs.
|
||||
llm = ChatXAI(model="grok-beta", foo=3, max_tokens=10) # type: ignore[call-arg]
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that if extra_kwargs are provided, they are added to it.
|
||||
llm = ChatXAI(model="grok-beta", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatXAI(model="grok-beta", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_function_dict_to_message_function_message() -> None:
|
||||
content = json.dumps({"result": "Example #1"})
|
||||
name = "test_function"
|
||||
result = _convert_dict_to_message(
|
||||
{
|
||||
"role": "function",
|
||||
"name": name,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
assert isinstance(result, FunctionMessage)
|
||||
assert result.name == name
|
||||
assert result.content == content
|
||||
|
||||
|
||||
def test_convert_dict_to_message_human() -> None:
|
||||
message = {"role": "user", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test__convert_dict_to_message_human_with_name() -> None:
|
||||
message = {"role": "user", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_ai_with_name() -> None:
|
||||
message = {"role": "assistant", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_system() -> None:
|
||||
message = {"role": "system", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_system_with_name() -> None:
|
||||
message = {"role": "system", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_tool() -> None:
|
||||
message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = ToolMessage(content="foo", tool_call_id="bar")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
@ -0,0 +1,35 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Tuple, Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found]
|
||||
ChatModelUnitTests, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
|
||||
class TestXAIStandard(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatXAI
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "grok-beta"}
|
||||
|
||||
@property
|
||||
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
|
||||
return (
|
||||
{
|
||||
"XAI_API_KEY": "api_key",
|
||||
},
|
||||
{
|
||||
"model": "grok-beta",
|
||||
},
|
||||
{
|
||||
"xai_api_key": "api_key",
|
||||
"xai_api_base": "https://api.x.ai/v1/",
|
||||
},
|
||||
)
|
7
libs/partners/xai/tests/unit_tests/test_imports.py
Normal file
7
libs/partners/xai/tests/unit_tests/test_imports.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain_xai import __all__
|
||||
|
||||
EXPECTED_ALL = ["ChatXAI"]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
7
libs/partners/xai/tests/unit_tests/test_secrets.py
Normal file
7
libs/partners/xai/tests/unit_tests/test_secrets.py
Normal file
@ -0,0 +1,7 @@
|
||||
from langchain_xai import ChatXAI
|
||||
|
||||
|
||||
def test_chat_xai_secrets() -> None:
|
||||
o = ChatXAI(model="grok-beta", xai_api_key="foo") # type: ignore[call-arg]
|
||||
s = str(o)
|
||||
assert "foo" not in s
|
Loading…
Reference in New Issue
Block a user