partners: add xAI chat integration (#28032)

Vadym Barda 2024-11-12 15:11:29 -05:00 committed by GitHub
parent 2898b95ca7
commit 48ee322a78
26 changed files with 3378 additions and 0 deletions


@@ -79,6 +79,7 @@ jobs:
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
UPSTAGE_API_KEY: ${{ secrets.UPSTAGE_API_KEY }}
XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
run: |
make integration_tests


@@ -308,6 +308,7 @@ jobs:
VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }}
UPSTAGE_API_KEY: ${{ secrets.UPSTAGE_API_KEY }}
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
XAI_API_KEY: ${{ secrets.XAI_API_KEY }}
run: make integration_tests
working-directory: ${{ inputs.working-directory }}


@@ -0,0 +1,332 @@
{
"cells": [
{
"cell_type": "raw",
"id": "afaf8039",
"metadata": {},
"source": [
"---\n",
"sidebar_label: xAI\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "e49f1e0d",
"metadata": {},
"source": [
"# ChatXAI\n",
"\n",
"\n",
"This page will help you get started with xAI [chat models](../../concepts/chat_models.mdx). For detailed documentation of all `ChatXAI` features and configurations head to the [API reference](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html).\n",
"\n",
"[xAI](https://console.x.ai/) offers an API to interact with Grok models.\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/docs/integrations/chat/xai) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatXAI](https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html) | [langchain-xai](https://python.langchain.com/api_reference/xai/index.html) | ❌ | beta | ✅ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-xai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-xai?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](../../how_to/tool_calling.ipynb) | [Structured output](../../how_to/structured_output.ipynb) | JSON mode | [Image input](../../how_to/multimodal_inputs.ipynb) | Audio input | Video input | [Token-level streaming](../../how_to/chat_streaming.ipynb) | Native async | [Token usage](../../how_to/chat_token_usage_tracking.ipynb) | [Logprobs](../../how_to/logprobs.ipynb) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ✅ | \n",
"\n",
"## Setup\n",
"\n",
"To access xAI models you'll need to create an xAI account, get an API key, and install the `langchain-xai` integration package.\n",
"\n",
"### Credentials\n",
"\n",
"Head to [this page](https://console.x.ai/) to sign up for xAI and generate an API key. Once you've done this set the `XAI_API_KEY` environment variable:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if \"XAI_API_KEY\" not in os.environ:\n",
" os.environ[\"XAI_API_KEY\"] = getpass.getpass(\"Enter your xAI API key: \")"
]
},
{
"cell_type": "markdown",
"id": "72ee0c4b-9764-423a-9dbf-95129e185210",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a15d341e-3e26-4ca3-830b-5aab30ed66de",
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")\n",
"# os.environ[\"LANGSMITH_TRACING\"] = \"true\""
]
},
{
"cell_type": "markdown",
"id": "0730d6a1-c893-4840-9817-5e5251676d5d",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain xAI integration lives in the `langchain-xai` package:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-xai"
]
},
{
"cell_type": "markdown",
"id": "a38cde65-254d-4219-a441-068766c0d4b5",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_xai import ChatXAI\n",
"\n",
"llm = ChatXAI(\n",
" model=\"grok-beta\",\n",
" temperature=0,\n",
" max_tokens=None,\n",
" timeout=None,\n",
" max_retries=2,\n",
" # other params...\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2b4f3e15",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "62e0dbc3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore programmer.\", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 30, 'total_tokens': 36, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'grok-beta', 'system_fingerprint': 'fp_14b89b2dfc', 'finish_reason': 'stop', 'logprobs': None}, id='run-adffb7a3-e48a-4f52-b694-340d85abe5c3-0', usage_metadata={'input_tokens': 30, 'output_tokens': 6, 'total_tokens': 36, 'input_token_details': {}, 'output_token_details': {}})"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"J'adore programmer.\n"
]
}
],
"source": [
"print(ai_msg.content)"
]
},
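{
"cell_type": "markdown",
"id": "streaming-note",
"metadata": {},
"source": [
"`ChatXAI` also supports token-level streaming. A minimal sketch (output omitted):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "streaming-example",
"metadata": {},
"outputs": [],
"source": [
"# Stream the translation token by token instead of waiting for the full reply\n",
"for chunk in llm.stream(messages):\n",
"    print(chunk.content, end=\"\", flush=True)"
]
},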
{
"cell_type": "markdown",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](../../how_to/sequence.ipynb) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe das Programmieren.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 7, 'prompt_tokens': 25, 'total_tokens': 32, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'grok-beta', 'system_fingerprint': 'fp_14b89b2dfc', 'finish_reason': 'stop', 'logprobs': None}, id='run-569fc8dc-101b-4e6d-864e-d4fa80df2b63-0', usage_metadata={'input_tokens': 25, 'output_tokens': 7, 'total_tokens': 32, 'input_token_details': {}, 'output_token_details': {}})"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e074bce1-0994-4b83-b393-ae7aa7e21750",
"metadata": {},
"source": [
"## Tool calling\n",
"\n",
"ChatXAI has a [tool calling](https://docs.x.ai/docs#capabilities) (we use \"tool calling\" and \"function calling\" interchangeably here) API that lets you describe tools and their arguments, and have the model return a JSON object with a tool to invoke and the inputs to that tool. Tool-calling is extremely useful for building tool-using chains and agents, and for getting structured outputs from models more generally.\n",
"\n",
"### ChatXAI.bind_tools()\n",
"\n",
"With `ChatXAI.bind_tools`, we can easily pass in Pydantic classes, dict schemas, LangChain tools, or even functions as tools to the model. Under the hood these are converted to an OpenAI tool schemas, which looks like:\n",
"```\n",
"{\n",
" \"name\": \"...\",\n",
" \"description\": \"...\",\n",
" \"parameters\": {...} # JSONSchema\n",
"}\n",
"```\n",
"and passed in every model invocation."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c6bfe929-ec02-46bd-9d54-76350edddabc",
"metadata": {},
"outputs": [],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"class GetWeather(BaseModel):\n",
" \"\"\"Get the current weather in a given location\"\"\"\n",
"\n",
" location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n",
"\n",
"\n",
"llm_with_tools = llm.bind_tools([GetWeather])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5265c892-d8c2-48af-aef5-adbee1647ba6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='I am retrieving the current weather for San Francisco.', additional_kwargs={'tool_calls': [{'id': '0', 'function': {'arguments': '{\"location\":\"San Francisco, CA\"}', 'name': 'GetWeather'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 11, 'prompt_tokens': 151, 'total_tokens': 162, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'grok-beta', 'system_fingerprint': 'fp_14b89b2dfc', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-73707da7-afec-4a52-bee1-a176b0ab8585-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'San Francisco, CA'}, 'id': '0', 'type': 'tool_call'}], usage_metadata={'input_tokens': 151, 'output_tokens': 11, 'total_tokens': 162, 'input_token_details': {}, 'output_token_details': {}})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ai_msg = llm_with_tools.invoke(\n",
" \"what is the weather like in San Francisco\",\n",
")\n",
"ai_msg"
]
},
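{
"cell_type": "markdown",
"id": "tool-results-note",
"metadata": {},
"source": [
"The `tool_calls` on the response can be executed and their results passed back as `ToolMessage`s so the model can produce a final answer. A minimal sketch, using a made-up weather value:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "tool-results-example",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import ToolMessage\n",
"\n",
"tool_call = ai_msg.tool_calls[0]\n",
"# Pretend we ran GetWeather; this result is a stand-in value\n",
"tool_result = \"65 degrees and foggy\"\n",
"followup = llm_with_tools.invoke(\n",
"    [\n",
"        (\"human\", \"what is the weather like in San Francisco\"),\n",
"        ai_msg,\n",
"        ToolMessage(content=tool_result, tool_call_id=tool_call[\"id\"]),\n",
"    ]\n",
")\n",
"followup.content"
]
},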
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all `ChatXAI` features and configurations head to the API reference: https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -0,0 +1,97 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2970dd75-8ebf-4b51-8282-9b454b8f356d",
"metadata": {},
"source": [
"# xAI\n",
"\n",
"[xAI](https://console.x.ai) offers an API to interact with Grok models.\n",
"\n",
"This example goes over how to use LangChain to interact with xAI models."
]
},
{
"cell_type": "markdown",
"id": "1c47fc36",
"metadata": {},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ecdb29d",
"metadata": {},
"outputs": [],
"source": [
"%pip install --upgrade langchain-xai"
]
},
{
"cell_type": "markdown",
"id": "89883202",
"metadata": {},
"source": [
"## Environment\n",
"\n",
"To use xAI, you'll need to [create an API key](https://console.x.ai/). The API key can be passed in as an init param ``xai_api_key`` or set as environment variable ``XAI_API_KEY``."
]
},
{
"cell_type": "markdown",
"id": "8304b4d9",
"metadata": {},
"source": [
"## Example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "637bb53f",
"metadata": {},
"outputs": [],
"source": [
"# Querying chat models with xAI\n",
"\n",
"from langchain_xai import ChatXAI\n",
"\n",
"chat = ChatXAI(\n",
" # xai_api_key=\"YOUR_API_KEY\",\n",
" model=\"grok-beta\",\n",
")\n",
"\n",
"# stream the response back from the model\n",
"for m in chat.stream(\"Tell me fun things to do in NYC\"):\n",
" print(m.content, end=\"\", flush=True)\n",
"\n",
"# if you don't want to do streaming, you can use the invoke method\n",
"# chat.invoke(\"Tell me fun things to do in NYC\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@@ -227,6 +227,17 @@ const FEATURE_TABLES = {
"local": false,
"apiLink": "https://python.langchain.com/api_reference/ibm/chat_models/langchain_ibm.chat_models.ChatWatsonx.html"
},
{
"name": "ChatXAI",
"package": "langchain-xai",
"link": "xai",
"structured_output": true,
"tool_calling": true,
"json_mode": false,
"multimodal": false,
"local": false,
"apiLink": "https://python.langchain.com/api_reference/xai/chat_models/langchain_xai.chat_models.ChatXAI.html"
},
],
},
llms: {


@@ -24,6 +24,7 @@ DEFAULT_NAMESPACES = [
"langchain_google_vertexai",
"langchain_mistralai",
"langchain_fireworks",
"langchain_xai",
]
# Namespaces for which only deserializing via the SERIALIZABLE_MAPPING is allowed.
# Load by path is not allowed.

libs/partners/xai/LICENSE

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,56 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/
test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/xai --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_xai
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_xai -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'


@@ -0,0 +1,17 @@
# langchain-xai
This package contains the LangChain integrations for [xAI](https://x.ai/) through their [APIs](https://console.x.ai).
## Installation and Setup
- Install the LangChain partner package
```bash
pip install -U langchain-xai
```
- Get your xAI API key from the [xAI Dashboard](https://console.x.ai) and set it as an environment variable (`XAI_API_KEY`)
## Chat Completions
This package contains the `ChatXAI` class, which is the recommended way to interface with xAI chat models.
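
A minimal usage sketch (assumes `XAI_API_KEY` is set):

```python
from langchain_xai import ChatXAI

# The API key is read from the XAI_API_KEY environment variable
llm = ChatXAI(model="grok-beta")
print(llm.invoke("Tell me a joke about programming.").content)
```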


@@ -0,0 +1,5 @@
"""This package provides the xAI integration for LangChain."""
from langchain_xai.chat_models import ChatXAI
__all__ = ["ChatXAI"]


@@ -0,0 +1,355 @@
"""Wrapper around xAI's Chat Completions API."""
from typing import (
Any,
Dict,
List,
Optional,
)
import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.utils import secret_from_env
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
class ChatXAI(BaseChatOpenAI): # type: ignore[override]
r"""ChatXAI chat model.
Setup:
Install ``langchain-xai`` and set the environment variable ``XAI_API_KEY``.
.. code-block:: bash
pip install -U langchain-xai
export XAI_API_KEY="your-api-key"
Key init args - completion params:
model: str
Name of model to use.
temperature: float
Sampling temperature.
max_tokens: Optional[int]
Max number of tokens to generate.
logprobs: Optional[bool]
Whether to return logprobs.
Key init args - client params:
timeout: Union[float, Tuple[float, float], Any, None]
Timeout for requests.
max_retries: int
Max number of retries.
api_key: Optional[str]
xAI API key. If not passed in, it will be read from the env var `XAI_API_KEY`.
Instantiate:
.. code-block:: python
from langchain_xai import ChatXAI
llm = ChatXAI(
model="grok-beta",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
# api_key="...",
# other params...
)
Invoke:
.. code-block:: python
messages = [
(
"system",
"You are a helpful translator. Translate the user sentence to French.",
),
("human", "I love programming."),
]
llm.invoke(messages)
.. code-block:: python
AIMessage(
content="J'adore la programmation.",
response_metadata={
'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
'model_name': 'grok-beta',
'system_fingerprint': None,
'finish_reason': 'stop',
'logprobs': None
},
id='run-168dceca-3b8b-4283-94e3-4c739dbc1525-0',
usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41})
Stream:
.. code-block:: python
for chunk in llm.stream(messages):
print(chunk)
.. code-block:: python
content='J' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content="'" id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='ad' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='ore' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content=' la' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content=' programm' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='ation' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='.' id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
content='' response_metadata={'finish_reason': 'stop', 'model_name': 'grok-beta'} id='run-1bc996b5-293f-4114-96a1-e0f755c05eb9'
Async:
.. code-block:: python
await llm.ainvoke(messages)
# stream:
# async for chunk in llm.astream(messages):
# batch:
# await llm.abatch([messages])
.. code-block:: python
AIMessage(
content="J'adore la programmation.",
response_metadata={
'token_usage': {'completion_tokens': 9, 'prompt_tokens': 32, 'total_tokens': 41},
'model_name': 'grok-beta',
'system_fingerprint': None,
'finish_reason': 'stop',
'logprobs': None
},
id='run-09371a11-7f72-4c53-8e7c-9de5c238b34c-0',
usage_metadata={'input_tokens': 32, 'output_tokens': 9, 'total_tokens': 41})
Tool calling:
.. code-block:: python
from pydantic import BaseModel, Field
llm = ChatXAI(model="grok-beta")
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke(
"Which city is bigger: LA or NY?"
)
ai_msg.tool_calls
.. code-block:: python
[
{
'name': 'GetPopulation',
'args': {'location': 'NY'},
'id': 'call_m5tstyn2004pre9bfuxvom8x',
'type': 'tool_call'
},
{
'name': 'GetPopulation',
'args': {'location': 'LA'},
'id': 'call_0vjgq455gq1av5sp9eb1pw6a',
'type': 'tool_call'
}
]
Structured output:
.. code-block:: python
from typing import Optional
from pydantic import BaseModel, Field
class Joke(BaseModel):
'''Joke to tell user.'''
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
structured_llm = llm.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats")
.. code-block:: python
Joke(
setup='Why was the cat sitting on the computer?',
punchline='To keep an eye on the mouse!',
rating=7
)
Token usage:
.. code-block:: python
ai_msg = llm.invoke(messages)
ai_msg.usage_metadata
.. code-block:: python
{'input_tokens': 37, 'output_tokens': 6, 'total_tokens': 43}
Logprobs:
.. code-block:: python
logprobs_llm = llm.bind(logprobs=True)
messages = [("human", "Say Hello World! Do not return anything else.")]
ai_msg = logprobs_llm.invoke(messages)
ai_msg.response_metadata["logprobs"]
.. code-block:: python
{
'content': None,
'token_ids': [22557, 3304, 28808, 2],
'tokens': [' Hello', ' World', '!', '</s>'],
'token_logprobs': [-4.7683716e-06, -5.9604645e-07, 0, -0.057373047]
}
Response metadata:
.. code-block:: python
ai_msg = llm.invoke(messages)
ai_msg.response_metadata
.. code-block:: python
{
'token_usage': {
'completion_tokens': 4,
'prompt_tokens': 19,
'total_tokens': 23
},
'model_name': 'grok-beta',
'system_fingerprint': None,
'finish_reason': 'stop',
'logprobs': None
}
""" # noqa: E501
model_name: str = Field(alias="model")
"""Model name to use."""
xai_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("XAI_API_KEY", default=None),
)
"""xAI API key.
Automatically read from env variable `XAI_API_KEY` if not provided.
"""
xai_api_base: str = Field(default="https://api.x.ai/v1/")
"""Base URL path for API requests."""
openai_api_key: Optional[SecretStr] = None
openai_api_base: Optional[str] = None
model_config = ConfigDict(
populate_by_name=True,
)
@property
def lc_secrets(self) -> Dict[str, str]:
"""A map of constructor argument names to secret ids.
For example,
{"xai_api_key": "XAI_API_KEY"}
"""
return {"xai_api_key": "XAI_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain_xai", "chat_models"]
@property
def lc_attributes(self) -> Dict[str, Any]:
"""List of attribute names that should be included in the serialized kwargs.
These attributes must be accepted by the constructor.
"""
attributes: Dict[str, Any] = {}
if self.xai_api_base:
attributes["xai_api_base"] = self.xai_api_base
return attributes
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this model can be serialized by Langchain."""
return True
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "xai-chat"
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get the parameters used to invoke the model."""
params = super()._get_ls_params(stop=stop, **kwargs)
params["ls_provider"] = "xai"
return params
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
if self.n < 1:
raise ValueError("n must be at least 1.")
if self.n > 1 and self.streaming:
raise ValueError("n must be 1 when streaming.")
client_params: dict = {
"api_key": (
self.xai_api_key.get_secret_value() if self.xai_api_key else None
),
"base_url": self.xai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
if client_params["api_key"] is None:
raise ValueError(
"xAI API key is not set. Please set it in the `xai_api_key` field or "
"in the `XAI_API_KEY` environment variable."
)
if not (self.client or None):
sync_specific: dict = {"http_client": self.http_client}
self.client = openai.OpenAI(
**client_params, **sync_specific
).chat.completions
if not (self.async_client or None):
async_specific: dict = {"http_client": self.http_async_client}
self.async_client = openai.AsyncOpenAI(
**client_params, **async_specific
).chat.completions
return self


libs/partners/xai/poetry.lock (generated)

File diff suppressed because it is too large.


@@ -0,0 +1,100 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-xai"
version = "0.1.0"
description = "An integration package connecting xAI and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.mypy]
disallow_untyped_defs = "True"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/xai"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-xai%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
langchain-openai = "^0.2"
langchain-core = "^0.3"
requests = "^2"
aiohttp = "^3.9.1"
[tool.ruff.lint]
select = ["E", "F", "I", "D"]
[tool.coverage.run]
omit = ["tests/*"]
[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.typing]
optional = true
[tool.poetry.group.dev]
optional = true
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["D"]
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
docarray = "^0.32.1"
langchain-openai = { path = "../openai", develop = true }
langchain-core = { path = "../../core", develop = true }
langchain-standard-tests = { path = "../../standard-tests", develop = true }
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration.dependencies]
[[tool.poetry.group.test_integration.dependencies.numpy]]
version = "^1"
python = "<3.12"
[[tool.poetry.group.test_integration.dependencies.numpy]]
version = "^1.26.0"
python = ">=3.12"
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
types-requests = "^2"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }


@@ -0,0 +1,19 @@
"""This module checks if the given python files can be imported without error."""
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file)
traceback.print_exc()
print()
sys.exit(1 if has_failure else 0)


@@ -0,0 +1,17 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi


@@ -0,0 +1,52 @@
"""Standard LangChain interface tests"""
from typing import Optional, Type
import pytest # type: ignore[import-not-found]
from langchain_core.language_models import BaseChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_standard_tests.integration_tests import ( # type: ignore[import-not-found]
ChatModelIntegrationTests, # type: ignore[import-not-found]
)
from langchain_xai import ChatXAI
# Initialize the rate limiter in global scope, so it can be re-used
# across tests.
rate_limiter = InMemoryRateLimiter(
requests_per_second=0.5,
)
class TestXAIStandard(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatXAI
@property
def chat_model_params(self) -> dict:
return {
"model": "grok-beta",
"rate_limiter": rate_limiter,
}
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "tool_name"
@pytest.mark.xfail(reason="Not yet supported.")
def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
super().test_usage_metadata_streaming(model)
@pytest.mark.xfail(reason="Can't handle AIMessage with empty content.")
def test_tool_message_error_status(self, model: BaseChatModel) -> None:
super().test_tool_message_error_status(model)
@pytest.mark.xfail(reason="Can't handle AIMessage with empty content.")
def test_structured_few_shot_examples(self, model: BaseChatModel) -> None:
super().test_structured_few_shot_examples(model)
@pytest.mark.xfail(reason="Can't handle AIMessage with empty content.")
def test_tool_message_histories_string_content(self, model: BaseChatModel) -> None:
super().test_tool_message_histories_string_content(model)


@@ -0,0 +1,7 @@
import pytest # type: ignore[import-not-found]
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass


@@ -0,0 +1,3 @@
import os
os.environ["XAI_API_KEY"] = "test"


@@ -0,0 +1,31 @@
# serializer version: 1
# name: TestXAIStandard.test_serdes[serialized]
dict({
'id': list([
'langchain_xai',
'chat_models',
'ChatXAI',
]),
'kwargs': dict({
'max_retries': 2,
'max_tokens': 100,
'model_name': 'grok-beta',
'n': 1,
'request_timeout': 60.0,
'stop': list([
]),
'temperature': 0.0,
'xai_api_base': 'https://api.x.ai/v1/',
'xai_api_key': dict({
'id': list([
'XAI_API_KEY',
]),
'lc': 1,
'type': 'secret',
}),
}),
'lc': 1,
'name': 'ChatXAI',
'type': 'constructor',
})
# ---


@@ -0,0 +1,129 @@
import json
import pytest # type: ignore[import-not-found]
from langchain_core.messages import (
AIMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_openai.chat_models.base import (
_convert_dict_to_message,
_convert_message_to_dict,
)
from langchain_xai import ChatXAI
def test_initialization() -> None:
"""Test chat model initialization."""
ChatXAI(model="grok-beta")
def test_xai_model_param() -> None:
llm = ChatXAI(model="foo")
assert llm.model_name == "foo"
llm = ChatXAI(model_name="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
ls_params = llm._get_ls_params()
assert ls_params["ls_provider"] == "xai"
def test_chat_xai_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError):
ChatXAI(
model="grok-beta",
max_tokens=10,
streaming=True,
temperature=0,
n=5,
)
def test_chat_xai_extra_kwargs() -> None:
"""Test extra kwargs to chat xai."""
# Check that foo is saved in extra_kwargs.
llm = ChatXAI(model="grok-beta", foo=3, max_tokens=10) # type: ignore[call-arg]
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
# Test that if extra_kwargs are provided, they are added to it.
llm = ChatXAI(model="grok-beta", foo=3, model_kwargs={"bar": 2}) # type: ignore[call-arg]
assert llm.model_kwargs == {"foo": 3, "bar": 2}
# Test that if provided twice it errors
with pytest.raises(ValueError):
ChatXAI(model="grok-beta", foo=3, model_kwargs={"foo": 2}) # type: ignore[call-arg]
def test_function_dict_to_message_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = _convert_dict_to_message(
{
"role": "function",
"name": name,
"content": content,
}
)
assert isinstance(result, FunctionMessage)
assert result.name == name
assert result.content == content
def test_convert_dict_to_message_human() -> None:
message = {"role": "user", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test__convert_dict_to_message_human_with_name() -> None:
message = {"role": "user", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = HumanMessage(content="foo", name="test")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_ai() -> None:
message = {"role": "assistant", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_ai_with_name() -> None:
message = {"role": "assistant", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = AIMessage(content="foo", name="test")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_system() -> None:
message = {"role": "system", "content": "foo"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_system_with_name() -> None:
message = {"role": "system", "content": "foo", "name": "test"}
result = _convert_dict_to_message(message)
expected_output = SystemMessage(content="foo", name="test")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message
def test_convert_dict_to_message_tool() -> None:
message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
result = _convert_dict_to_message(message)
expected_output = ToolMessage(content="foo", tool_call_id="bar")
assert result == expected_output
assert _convert_message_to_dict(expected_output) == message


@@ -0,0 +1,35 @@
"""Standard LangChain interface tests"""
from typing import Tuple, Type
from langchain_core.language_models import BaseChatModel
from langchain_standard_tests.unit_tests import ( # type: ignore[import-not-found]
ChatModelUnitTests, # type: ignore[import-not-found]
)
from langchain_xai import ChatXAI
class TestXAIStandard(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatXAI
@property
def chat_model_params(self) -> dict:
return {"model": "grok-beta"}
@property
def init_from_env_params(self) -> Tuple[dict, dict, dict]:
return (
{
"XAI_API_KEY": "api_key",
},
{
"model": "grok-beta",
},
{
"xai_api_key": "api_key",
"xai_api_base": "https://api.x.ai/v1/",
},
)


@@ -0,0 +1,7 @@
from langchain_xai import __all__
EXPECTED_ALL = ["ChatXAI"]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)


@@ -0,0 +1,7 @@
from langchain_xai import ChatXAI
def test_chat_xai_secrets() -> None:
o = ChatXAI(model="grok-beta", xai_api_key="foo") # type: ignore[call-arg]
s = str(o)
assert "foo" not in s