AI21: tool calling support in LangChain (#25635)

This pull request introduces support for the AI21 tool calling feature, available
in the Jamba-1.5 models. Tools are supplied through the 'tools' parameter passed
to the model, with each entry structured as follows:

```python
from typing import Any, Dict, List, Literal

from typing_extensions import Required, TypedDict

class ToolDefinition(TypedDict, total=False):
    type: Required[Literal["function"]]
    function: Required[FunctionToolDefinition]

class FunctionToolDefinition(TypedDict, total=False):
    name: Required[str]
    description: str
    parameters: ToolParameters

class ToolParameters(TypedDict, total=False):
    type: Literal["object"]
    properties: Required[Dict[str, Any]]
    required: List[str]
```

When Jamba-1.5 determines that a provided tool should be invoked, it responds
with a list of tool calls structured as follows:

```python
class ToolCall(AI21BaseModel):
    id: str
    function: ToolFunction
    type: Literal["function"] = "function"

class ToolFunction(AI21BaseModel):
    name: str
    arguments: str
```
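
For illustration, here is a minimal sketch of what these structures look like as concrete values, assuming the definitions above (the `get_weather` tool, its parameters, and the call id are hypothetical):

```python
# Request side: one entry of the 'tools' parameter, shaped like ToolDefinition.
weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool name
        "description": "Provide the weather for a location on a given date.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "date": {"type": "string"},
            },
            "required": ["location", "date"],
        },
    },
}

# Response side: a ToolCall as sketched above. Note that `arguments`
# is a JSON string to be parsed by the caller, not a dict.
# ToolCall(
#     id="call_abc123",  # illustrative id
#     type="function",
#     function=ToolFunction(
#         name="get_weather",
#         arguments='{"location": "New York", "date": "2024-12-05"}',
#     ),
# )
```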

This pull request includes the modifications needed to integrate this
functionality into the langchain-ai21 library.

---------

Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: pazshalev <111360591+pazshalev@users.noreply.github.com>
Co-authored-by: Paz Shalev <pazs@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Authored by amirai21 on 2024-08-26 20:50:30 +03:00; committed by GitHub
parent a566a15930
commit 17dffd9741
12 changed files with 397 additions and 119 deletions

@@ -150,4 +150,68 @@ from langchain_ai21 import AI21SemanticTextSplitter
splitter = AI21SemanticTextSplitter()
response = splitter.split_text("Your text")
```
## Tool calls
### Function calling
AI21 models support function calling for custom user functions. The model generates structured
data that includes the function name and proposed arguments. Applications can use this data to call
external APIs and incorporate the results into subsequent model prompts, enriching responses with
real-time data and context. Through function calling, users can draw on services such as
transportation APIs and financial data providers to obtain more accurate and relevant answers.
Here is an example of how to use function calling with AI21 models in LangChain:
```python
import os
from getpass import getpass

from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_ai21.chat_models import ChatAI21

os.environ["AI21_API_KEY"] = getpass()


@tool
def get_weather(location: str, date: str) -> str:
    """Provide the weather for the specified location on the given date."""
    if location == "New York" and date == "2024-12-05":
        return "25 celsius"
    elif location == "New York" and date == "2024-12-06":
        return "27 celsius"
    elif location == "London" and date == "2024-12-05":
        return "22 celsius"
    return "32 celsius"


llm = ChatAI21(model="jamba-1.5-mini")
llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])

chat_messages = [
    SystemMessage(
        content="You are a helpful assistant. You can use the provided tools "
        "to assist with various tasks and provide accurate information"
    )
]

human_messages = [
    HumanMessage(content="What is the forecast for the weather in New York on December 5, 2024?"),
    HumanMessage(content="And what about on 2024-12-06?"),
    HumanMessage(content="OK, thank you."),
    HumanMessage(content="What is the expected weather in London on December 5, 2024?"),
]

for human_message in human_messages:
    print(f"User: {human_message.content}")
    chat_messages.append(human_message)
    response = llm_with_tools.invoke(chat_messages)
    chat_messages.append(response)
    if response.tool_calls:
        tool_call = response.tool_calls[0]
        if tool_call["name"] == "get_weather":
            weather = get_weather.invoke(
                {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
            )
            chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
            llm_answer = llm_with_tools.invoke(chat_messages)
            print(f"Assistant: {llm_answer.content}")
    else:
        print(f"Assistant: {response.content}")
```

@@ -1,11 +1,20 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload
from typing import Any, Dict, Iterator, List, Literal, Optional, Union, cast, overload
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import ChatCompletionChunk, ChatMessage
from ai21.models.chat import (
AssistantMessage as AI21AssistantMessage,
)
from ai21.models.chat import ChatCompletionChunk, ChatMessageParam
from ai21.models.chat import ChatMessage as AI21ChatMessage
from ai21.models.chat import SystemMessage as AI21SystemMessage
from ai21.models.chat import ToolCall as AI21ToolCall
from ai21.models.chat import ToolFunction as AI21ToolFunction
from ai21.models.chat import ToolMessage as AI21ToolMessage
from ai21.models.chat import UserMessage as AI21UserMessage
from ai21.stream.stream import Stream as AI21Stream
from langchain_core.messages import (
AIMessage,
@@ -13,11 +22,15 @@ from langchain_core.messages import (
BaseMessage,
BaseMessageChunk,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers.openai_tools import parse_tool_call
from langchain_core.outputs import ChatGenerationChunk
_ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
_ChatMessageTypes = Union[AI21ChatMessage, J2ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
_ROLE_TYPE = Union[str, RoleType]
@@ -40,20 +53,24 @@ class ChatAdapter(ABC):
self,
message: BaseMessage,
) -> _ChatMessageTypes:
content = cast(str, message.content)
role = self._parse_role(message)
return self._chat_message(role=role, content=content)
return self._chat_message(role=role, message=message)
def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
role = None
if isinstance(message, SystemMessage):
return RoleType.SYSTEM
if isinstance(message, HumanMessage):
return RoleType.USER
if isinstance(message, AIMessage):
return RoleType.ASSISTANT
if isinstance(message, ToolMessage):
return RoleType.TOOL
if isinstance(self, J2ChatAdapter):
if not role:
raise ValueError(
@@ -68,7 +85,7 @@ class ChatAdapter(ABC):
def _chat_message(
self,
role: _ROLE_TYPE,
content: str,
message: BaseMessage,
) -> _ChatMessageTypes:
pass
@@ -130,9 +147,9 @@ class J2ChatAdapter(ChatAdapter):
def _chat_message(
self,
role: _ROLE_TYPE,
content: str,
message: BaseMessage,
) -> J2ChatMessage:
return J2ChatMessage(role=RoleType(role), text=content)
return J2ChatMessage(role=RoleType(role), text=cast(str, message.content))
@overload
def call(
@@ -174,12 +191,65 @@ class JambaChatCompletionsAdapter(ChatAdapter):
],
}
def _convert_lc_tool_calls_to_ai21_tool_calls(
self, tool_calls: List[ToolCall]
) -> Optional[List[AI21ToolCall]]:
"""
Convert LangChain ToolCalls to AI21 ToolCalls.
"""
ai21_tool_calls: List[AI21ToolCall] = []
for lc_tool_call in tool_calls:
if "id" not in lc_tool_call or not lc_tool_call["id"]:
raise ValueError("Tool call ID is missing or empty.")
ai21_tool_call = AI21ToolCall(
id=lc_tool_call["id"],
type="function",
function=AI21ToolFunction(
name=lc_tool_call["name"],
arguments=str(lc_tool_call["args"]),
),
)
ai21_tool_calls.append(ai21_tool_call)
return ai21_tool_calls
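
As an aside, a hedged sketch of what this helper produces for a single LangChain tool call (all values hypothetical). Note that `args` is stringified with `str()`, so `arguments` ends up in Python repr form rather than JSON:

```python
# A LangChain ToolCall is a dict with "name", "args", and "id" keys.
lc_tool_call = {
    "name": "get_weather",  # hypothetical tool name
    "args": {"location": "New York", "date": "2024-12-05"},
    "id": "call_abc123",  # hypothetical id
}

# _convert_lc_tool_calls_to_ai21_tool_calls([lc_tool_call]) would yield roughly:
# [AI21ToolCall(
#     id="call_abc123",
#     type="function",
#     function=AI21ToolFunction(
#         name="get_weather",
#         arguments="{'location': 'New York', 'date': '2024-12-05'}",  # str(args)
#     ),
# )]
```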
def _get_content_as_string(self, base_message: BaseMessage) -> str:
if isinstance(base_message.content, str):
return base_message.content
elif isinstance(base_message.content, list):
return "\n".join(str(item) for item in base_message.content)
else:
raise ValueError("Unsupported content type")
def _chat_message(
self,
role: _ROLE_TYPE,
content: str,
) -> ChatMessage:
return ChatMessage(
message: BaseMessage,
) -> ChatMessageParam:
content = self._get_content_as_string(message)
if isinstance(message, AIMessage):
return AI21AssistantMessage(
tool_calls=self._convert_lc_tool_calls_to_ai21_tool_calls(
message.tool_calls
),
content=content or None,
)
if isinstance(message, ToolMessage):
return AI21ToolMessage(
tool_call_id=message.tool_call_id,
content=content,
)
if isinstance(message, HumanMessage):
return AI21UserMessage(
content=content,
)
if isinstance(message, SystemMessage):
return AI21SystemMessage(
content=content,
)
return AI21ChatMessage(
role=role.value if isinstance(role, RoleType) else role,
content=content,
)
@@ -211,7 +281,18 @@ class JambaChatCompletionsAdapter(ChatAdapter):
if stream:
return self._stream_response(response)
return [AIMessage(choice.message.content) for choice in response.choices]
ai_messages: List[BaseMessage] = []
for message in response.choices:
if message.message.tool_calls:
tool_calls = [
parse_tool_call(tool_call.model_dump(), return_id=True)
for tool_call in message.message.tool_calls
]
ai_messages.append(AIMessage("", tool_calls=tool_calls))
else:
ai_messages.append(AIMessage(message.message.content))
return ai_messages
def _stream_response(
self,

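In the other direction, the response-parsing branch above hands each SDK tool call to `parse_tool_call` from langchain-core. A hedged sketch of that conversion, with an illustrative payload:

```python
from langchain_core.output_parsers.openai_tools import parse_tool_call

# An AI21 SDK tool call, dumped to a dict via model_dump() (values illustrative).
raw_tool_call = {
    "id": "call_abc123",
    "type": "function",
    "function": {
        "name": "get_weather",
        "arguments": '{"location": "New York", "date": "2024-12-05"}',
    },
}

tool_call = parse_tool_call(raw_tool_call, return_id=True)
# -> a LangChain-style dict carrying the tool "name", the JSON-decoded "args",
#    and the original "id", ready to attach to an AIMessage.
```
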
@@ -1,11 +1,23 @@
import asyncio
from functools import partial
from typing import Any, Dict, Iterator, List, Mapping, Optional
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
@@ -16,6 +28,9 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_ai21.ai21_base import AI21Base
from langchain_ai21.chat.chat_adapter import ChatAdapter
@@ -48,14 +63,14 @@ class ChatAI21(BaseChatModel, AI21Base):
stop: Optional[List[str]] = None
"""Default stop sequences."""
max_tokens: int = 16
max_tokens: int = 512
"""The maximum number of tokens to generate for each response."""
min_tokens: int = 0
"""The minimum number of tokens to generate for each response.
_Not supported for all models._"""
temperature: float = 0.7
temperature: float = 0.4
"""A value controlling the "creativity" of the model's responses."""
top_p: float = 1
@@ -246,3 +261,11 @@ class ChatAI21(BaseChatModel, AI21Base):
)
return message.content
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
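
With `bind_tools` in place, attaching tools to the chat model is a one-liner. A minimal usage sketch, assuming a hypothetical `multiply` tool (conversion to the OpenAI tool format happens inside `bind_tools`, so passing the decorated function directly is enough):

```python
from langchain_core.tools import tool

from langchain_ai21.chat_models import ChatAI21


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


llm = ChatAI21(model="jamba-1.5-mini")  # assumes AI21_API_KEY is set
llm_with_tools = llm.bind_tools([multiply])
response = llm_with_tools.invoke("What is 6 times 7?")
print(response.tool_calls)  # expected to include a call to "multiply"
```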

@@ -1,35 +1,36 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "ai21"
version = "2.7.0"
version = "2.14.1"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
{file = "ai21-2.7.0-py3-none-any.whl", hash = "sha256:9060aa90f0acc21ce1e3ad90c814762ba0914dd5af073c269868dbcdf5ecd108"},
{file = "ai21-2.7.0.tar.gz", hash = "sha256:3f86f47af67fa43b086773aa01d89286ec2011dbc1a4a53aaca3a104ac1f958f"},
{file = "ai21-2.14.1-py3-none-any.whl", hash = "sha256:618c0b5c025123c703645258472330a07ae2de17020438f6a33b29668275995c"},
{file = "ai21-2.14.1.tar.gz", hash = "sha256:05d9626b82206e0a5be43d17c39d4e0c7b51c2b7634d2ea38a2c698ac3e2fd5b"},
]
[package.dependencies]
ai21-tokenizer = ">=0.11.0,<1.0.0"
dataclasses-json = ">=0.6.3,<0.7.0"
ai21-tokenizer = ">=0.12.0,<1.0.0"
httpx = ">=0.27.0,<0.28.0"
pydantic = ">=1.9.0,<3.0.0"
tenacity = ">=8.3.0,<9.0.0"
typing-extensions = ">=4.9.0,<5.0.0"
[package.extras]
aws = ["boto3 (>=1.28.82,<2.0.0)"]
vertex = ["google-auth (>=2.31.0,<3.0.0)"]
[[package]]
name = "ai21-tokenizer"
version = "0.11.2"
version = "0.12.0"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
{file = "ai21_tokenizer-0.11.2-py3-none-any.whl", hash = "sha256:a9444ca44ef2bffec7cb9f0c3cfa5501dc973cdde0b740e43e137ce9a2f90eab"},
{file = "ai21_tokenizer-0.11.2.tar.gz", hash = "sha256:35579bca375f071ae6365456f02bd5c9445f408723f7b87646a2bdaa3f57925e"},
{file = "ai21_tokenizer-0.12.0-py3-none-any.whl", hash = "sha256:7fd37b9093894b30b0f200e5f44fc8fb8772e2b272ef71b6d73722b4696e63c4"},
{file = "ai21_tokenizer-0.12.0.tar.gz", hash = "sha256:d2a5b17789d21572504b7693148bf66e692bdb3ab563023dbcbee340bcbd11c6"},
]
[package.dependencies]
@@ -211,21 +212,6 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "dataclasses-json"
version = "0.6.6"
description = "Easily serialize dataclasses to and from JSON."
optional = false
python-versions = "<4.0,>=3.7"
files = [
{file = "dataclasses_json-0.6.6-py3-none-any.whl", hash = "sha256:e54c5c87497741ad454070ba0ed411523d46beb5da102e221efb873801b0ba85"},
{file = "dataclasses_json-0.6.6.tar.gz", hash = "sha256:0c09827d26fffda27f1be2fed7a7a01a29c5ddcd2eb6393ad5ebf9d77e9deae8"},
]
[package.dependencies]
marshmallow = ">=3.18.0,<4.0.0"
typing-inspect = ">=0.4.0,<1"
[[package]]
name = "exceptiongroup"
version = "1.2.1"
@@ -448,7 +434,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.2.11"
version = "0.2.33"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@@ -465,6 +451,7 @@ pydantic = [
]
PyYAML = ">=5.3"
tenacity = "^8.1.0,!=8.4.0"
typing-extensions = ">=4.7"
[package.source]
type = "directory"
@@ -490,7 +477,7 @@ url = "../../standard-tests"
[[package]]
name = "langchain-text-splitters"
version = "0.2.2"
version = "0.2.3"
description = "LangChain text splitting utilities"
optional = false
python-versions = ">=3.8.1,<4.0"
@@ -523,25 +510,6 @@ pydantic = [
]
requests = ">=2,<3"
[[package]]
name = "marshmallow"
version = "3.21.2"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
optional = false
python-versions = ">=3.8"
files = [
{file = "marshmallow-3.21.2-py3-none-any.whl", hash = "sha256:70b54a6282f4704d12c0a41599682c5c5450e843b9ec406308653b47c59648a1"},
{file = "marshmallow-3.21.2.tar.gz", hash = "sha256:82408deadd8b33d56338d2182d455db632c6313aa2af61916672146bb32edc56"},
]
[package.dependencies]
packaging = ">=17.0"
[package.extras]
dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"]
docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"]
tests = ["pytest", "pytz", "simplejson"]
[[package]]
name = "mypy"
version = "1.10.1"
@@ -1257,21 +1225,6 @@ files = [
{file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"},
]
[[package]]
name = "typing-inspect"
version = "0.9.0"
description = "Runtime inspection utilities for typing module."
optional = false
python-versions = "*"
files = [
{file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"},
{file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"},
]
[package.dependencies]
mypy-extensions = ">=0.3.0"
typing-extensions = ">=3.7.4"
[[package]]
name = "urllib3"
version = "2.2.1"
@@ -1336,4 +1289,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "9dee8f52fd10c8ffe640c33620d7a33daa989ba38f56f7a33a2e98a734457015"
content-hash = "32e1777e151eef2eb3775c4f1707ec4e80241a08fd4beed2d60a36f58d68dfea"

@@ -22,7 +22,7 @@ disallow_untyped_defs = "True"
python = ">=3.8.1,<4.0"
langchain-core = "^0.2.4"
langchain-text-splitters = "^0.2.0"
ai21 = "^2.7.0"
ai21 = "^2.14.1"
[tool.ruff.lint]
select = [ "E", "F", "I",]

@@ -1,12 +1,25 @@
"""Test ChatAI21 chat model."""
import pytest
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.messages import (
AIMessageChunk,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_ai21.chat_models import ChatAI21
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
from tests.unit_tests.conftest import (
J2_CHAT_MODEL_NAME,
JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
JAMBA_1_5_MINI_CHAT_MODEL_NAME,
JAMBA_CHAT_MODEL_NAME,
JAMBA_FAMILY_MODEL_NAMES,
)
rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
@@ -15,11 +28,15 @@ rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
ids=[
"when_j2_model",
"when_jamba_model",
"when_jamba1.5-mini_model",
"when_jamba1.5-large_model",
],
argnames=["model"],
argvalues=[
(J2_CHAT_MODEL_NAME,),
(JAMBA_CHAT_MODEL_NAME,),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
],
)
def test_invoke(model: str) -> None:
@@ -36,6 +53,10 @@ def test_invoke(model: str) -> None:
"when_j2_model_num_results_is_3",
"when_jamba_model_n_is_1",
"when_jamba_model_n_is_3",
"when_jamba1.5_mini_model_n_is_1",
"when_jamba1.5_mini_model_n_is_3",
"when_jamba1.5_large_model_n_is_1",
"when_jamba1.5_large_model_n_is_3",
],
argnames=["model", "num_results"],
argvalues=[
@@ -43,12 +64,16 @@ def test_invoke(model: str) -> None:
(J2_CHAT_MODEL_NAME, 3),
(JAMBA_CHAT_MODEL_NAME, 1),
(JAMBA_CHAT_MODEL_NAME, 3),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME, 1),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME, 3),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 1),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 3),
],
)
def test_generation(model: str, num_results: int) -> None:
"""Test generation with multiple models and different result counts."""
# Determine the configuration key based on the model type
config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
config_key = "n" if model in JAMBA_FAMILY_MODEL_NAMES else "num_results"
# Create the model instance using the appropriate key for the result count
llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results})  # type: ignore[arg-type]
@@ -69,11 +94,15 @@ def test_generation(model: str, num_results: int) -> None:
ids=[
"when_j2_model",
"when_jamba_model",
"when_jamba1.5_mini_model",
"when_jamba1.5_large_model",
],
argnames=["model"],
argvalues=[
(J2_CHAT_MODEL_NAME,),
(JAMBA_CHAT_MODEL_NAME,),
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
],
)
async def test_ageneration(model: str) -> None:
@@ -92,7 +121,7 @@ async def test_ageneration(model: str) -> None:
def test__chat_stream() -> None:
llm = ChatAI21(model="jamba-instruct") # type: ignore[call-arg]
llm = ChatAI21(model="jamba-1.5-mini") # type: ignore[call-arg]
message = HumanMessage(content="What is the meaning of life?")
for chunk in llm.stream([message]):
@@ -107,3 +136,53 @@ def test__j2_chat_stream__should_raise_error() -> None:
with pytest.raises(NotImplementedError):
for _ in llm.stream([message]):
pass
@pytest.mark.parametrize(
ids=[
"when_jamba1.5_mini_model",
"when_jamba1.5_large_model",
],
argnames=["model"],
argvalues=[
(JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
(JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
],
)
def test_tool_calls(model: str) -> None:
@tool
def get_weather(location: str, date: str) -> str:
"""“Provide the weather for the specified location on the given date.”"""
if location == "New York" and date == "2024-12-05":
return "25 celsius"
return "32 celsius"
llm = ChatAI21(model=model, temperature=0) # type: ignore[call-arg]
llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])
chat_messages = [
SystemMessage(
content="You are a helpful assistant. "
"You can use the provided tools "
"to assist with various tasks and provide "
"accurate information"
),
HumanMessage(
content="What is the forecast for the weather "
"in New York on December 5, 2024?"
),
]
response = llm_with_tools.invoke(chat_messages)
chat_messages.append(response)
assert response.tool_calls is not None # type: ignore[attr-defined]
tool_call = response.tool_calls[0] # type: ignore[attr-defined]
assert tool_call["name"] == "get_weather"
weather = get_weather.invoke( # type: ignore[attr-defined]
{"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
)
chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
llm_answer = llm_with_tools.invoke(chat_messages)
content = llm_answer.content.lower() # type: ignore[union-attr]
assert "new york" in content and "25" in content and "celsius" in content

@@ -1,7 +1,7 @@
"""Standard LangChain interface tests"""
import time
from typing import Type
from typing import Optional, Type
import pytest
from langchain_core.language_models import BaseChatModel
@@ -28,6 +28,8 @@ class BaseTestAI21(ChatModelIntegrationTests):
class TestAI21J2(BaseTestAI21):
has_tool_calling = False
@property
def chat_model_params(self) -> dict:
return {
@@ -49,8 +51,23 @@ class TestAI21J2(BaseTestAI21):
class TestAI21Jamba(BaseTestAI21):
has_tool_calling = False
@property
def chat_model_params(self) -> dict:
return {
"model": "jamba-instruct-preview",
}
class TestAI21Jamba1_5(BaseTestAI21):
@property
def tool_choice_value(self) -> Optional[str]:
"""Value to use for tool choice when used in tests."""
return "any"
@property
def chat_model_params(self) -> dict:
return {
"model": "jamba-1.5-mini",
}

@@ -3,12 +3,21 @@ from typing import List
import pytest
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import ChatMessage
from ai21.models.chat import (
AssistantMessage,
ChatMessage,
UserMessage,
)
from ai21.models.chat import (
SystemMessage as AI21SystemMessage,
)
from ai21.models.chat import ToolMessage as AI21ToolMessage
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.messages import (
ChatMessage as LangChainChatMessage,
@@ -18,6 +27,8 @@ from langchain_ai21.chat.chat_adapter import ChatAdapter
_J2_MODEL_NAME = "j2-ultra"
_JAMBA_MODEL_NAME = "jamba-instruct-preview"
_JAMBA_1_5_MINI_MODEL_NAME = "jamba-1.5-mini"
_JAMBA_1_5_LARGE_MODEL_NAME = "jamba-1.5-large"
@pytest.mark.parametrize(
@@ -42,12 +53,14 @@ _JAMBA_MODEL_NAME = "jamba-instruct-preview"
(
_JAMBA_MODEL_NAME,
HumanMessage(content="Human Message Content"),
ChatMessage(role=RoleType.USER, content="Human Message Content"),
UserMessage(role="user", content="Human Message Content"),
),
(
_JAMBA_MODEL_NAME,
AIMessage(content="AI Message Content"),
ChatMessage(role=RoleType.ASSISTANT, content="AI Message Content"),
AssistantMessage(
role="assistant", content="AI Message Content", tool_calls=[]
),
),
],
)
@@ -69,7 +82,7 @@ def test_convert_message_to_ai21_message(
argvalues=[
(
_J2_MODEL_NAME,
SystemMessage(content="System Message Content"),
AI21SystemMessage(content="System Message Content"),
),
(
_J2_MODEL_NAME,
@@ -95,6 +108,8 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
"when_first_message_is_system__should_return_system_j2_model",
"when_all_messages_are_human_messages__should_return_system_none_jamba_model",
"when_first_message_is_system__should_return_system_jamba_model",
"when_tool_calling_message__should_return_tool_jamba_mini_model",
"when_tool_calling_message__should_return_tool_jamba_large_model",
],
argnames=["model", "messages", "expected_messages"],
argvalues=[
@@ -142,12 +157,12 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
],
{
"messages": [
ChatMessage(
role=RoleType.USER,
UserMessage(
role="user",
content="Human Message Content 1",
),
ChatMessage(
role=RoleType.USER,
UserMessage(
role="user",
content="Human Message Content 2",
),
]
@@ -161,8 +176,46 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
],
{
"messages": [
ChatMessage(role="system", content="System Message Content 1"),
ChatMessage(role="user", content="Human Message Content 1"),
AI21SystemMessage(
role="system", content="System Message Content 1"
),
UserMessage(role="user", content="Human Message Content 1"),
],
},
),
(
_JAMBA_1_5_MINI_MODEL_NAME,
[
ToolMessage(
content="42",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
)
],
{
"messages": [
AI21ToolMessage(
role="tool",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
content="42",
),
],
},
),
(
_JAMBA_1_5_LARGE_MODEL_NAME,
[
ToolMessage(
content="42",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
)
],
{
"messages": [
AI21ToolMessage(
role="tool",
tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
content="42",
),
],
},
),

@@ -23,8 +23,16 @@ from pytest_mock import MockerFixture
J2_CHAT_MODEL_NAME = "j2-ultra"
JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
JAMBA_1_5_MINI_CHAT_MODEL_NAME = "jamba-1.5-mini"
JAMBA_1_5_LARGE_CHAT_MODEL_NAME = "jamba-1.5-large"
DUMMY_API_KEY = "test_api_key"
JAMBA_FAMILY_MODEL_NAMES = [
JAMBA_CHAT_MODEL_NAME,
JAMBA_1_5_MINI_CHAT_MODEL_NAME,
JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
]
BASIC_EXAMPLE_LLM_PARAMETERS = {
"num_results": 3,
"max_tokens": 20,
@@ -32,9 +40,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
"count_penalty": Penalty(
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
@@ -48,9 +56,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS = {
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
"count_penalty": Penalty(
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
@@ -59,7 +67,7 @@
}
SEGMENTS = [
Segment(
Segment( # type: ignore[call-arg]
segment_type="normal_text",
segment_text=(
"The original full name of the franchise is Pocket Monsters "
@@ -70,7 +78,7 @@
"in pronunciation."
),
),
Segment(
Segment( # type: ignore[call-arg]
segment_type="normal_text",
segment_text=(
"Pokémon refers to both the franchise itself and the creatures "
@@ -92,9 +100,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
"count_penalty": Penalty(
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
@@ -108,9 +116,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
"count_penalty": Penalty(
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(), # type: ignore[call-arg]
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(), # type: ignore[call-arg]
"count_penalty": Penalty( # type: ignore[call-arg]
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
@@ -124,7 +132,7 @@ def mocked_completion_response(mocker: MockerFixture) -> Mock:
mocked_response = mocker.MagicMock(spec=CompletionsResponse)
mocked_response.prompt = "this is a test prompt"
mocked_response.completions = [
Completion(
Completion( # type: ignore[call-arg]
data=CompletionData(text="test", tokens=[]),
finish_reason=CompletionFinishReason(reason=None, length=None),
)
@@ -152,7 +160,7 @@ def mock_client_with_chat(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.chat = mocker.MagicMock()
output = ChatOutput(
output = ChatOutput( # type: ignore[call-arg]
text="Hello Pickle Rick!",
role=RoleType.ASSISTANT,
finish_reason=FinishReason(reason="testing"),
@@ -178,7 +186,7 @@ def temporarily_unset_api_key() -> Generator:
def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.answer = mocker.MagicMock()
mock_client.answer.create.return_value = AnswerResponse(
mock_client.answer.create.return_value = AnswerResponse( # type: ignore[call-arg]
id="some_id",
answer="some answer",
answer_in_context=False,

@@ -45,9 +45,9 @@ def test_initialization__when_custom_parameters_in_init() -> None:
temperature = 0.1
top_p = 0.1
top_k_return = 0
frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True) # type: ignore[call-arg]
presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True) # type: ignore[call-arg]
count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True) # type: ignore[call-arg]
llm = ChatAI21( # type: ignore[call-arg]
api_key=DUMMY_API_KEY, # type: ignore[arg-type]

@@ -17,9 +17,9 @@ _EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]
_EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
results=[
EmbedResult(_EXAMPLE_EMBEDDING_0),
EmbedResult(_EXAMPLE_EMBEDDING_1),
EmbedResult(_EXAMPLE_EMBEDDING_2),
EmbedResult(embedding=_EXAMPLE_EMBEDDING_0),
EmbedResult(embedding=_EXAMPLE_EMBEDDING_1),
EmbedResult(embedding=_EXAMPLE_EMBEDDING_2),
],
id="test_id",
)

@@ -49,9 +49,9 @@ def test_initialization__when_custom_parameters_to_init() -> None:
top_p=0.5,
top_k_return=0,
stop_sequences=["\n"],
frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
count_penalty=Penalty(
frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True), # type: ignore[call-arg]
presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True), # type: ignore[call-arg]
count_penalty=Penalty( # type: ignore[call-arg]
scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
),
custom_model="test_model",