mirror of https://github.com/hwchase17/langchain
AI21: tools calling support in Langchain (#25635)
This pull request introduces support for the AI21 tool-calling feature, available in the Jamba-1.5 models. Tools are provided to the model through the `tools` parameter, whose entries have the following structure:

```
class ToolDefinition(TypedDict, total=False):
    type: Required[Literal["function"]]
    function: Required[FunctionToolDefinition]


class FunctionToolDefinition(TypedDict, total=False):
    name: Required[str]
    description: str
    parameters: ToolParameters


class ToolParameters(TypedDict, total=False):
    type: Literal["object"]
    properties: Required[Dict[str, Any]]
    required: List[str]
```

When Jamba-1.5 detects that one of the provided tools should be invoked, it responds with a list of tool calls structured as follows:

```
class ToolCall(AI21BaseModel):
    id: str
    function: ToolFunction
    type: Literal["function"] = "function"


class ToolFunction(AI21BaseModel):
    name: str
    arguments: str
```

This pull request incorporates the modifications needed to integrate this functionality into the langchain-ai21 library.

---------

Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: pazshalev <111360591+pazshalev@users.noreply.github.com>
Co-authored-by: Paz Shalev <pazs@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
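For illustration, a tool definition matching the `ToolDefinition` shape above might look like the following sketch. The `get_weather` name and its parameter schema are hypothetical, not part of this PR:

```
# Hypothetical example: a plain dict in the shape described by ToolDefinition,
# suitable for passing in the model's 'tools' parameter.
get_weather_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool name
        "description": "Provide the weather for a location on a given date.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "date": {"type": "string"},
            },
            "required": ["location", "date"],
        },
    },
}

# A matching response would be a list of ToolCall objects, e.g. (illustrative values):
#   ToolCall(
#       id="call_123",
#       type="function",
#       function=ToolFunction(
#           name="get_weather",
#           arguments='{"location": "New York", "date": "2024-12-05"}',
#       ),
#   )
```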
This commit is contained in:
parent a566a15930
commit 17dffd9741
@@ -150,4 +150,68 @@ from langchain_ai21 import AI21SemanticTextSplitter

splitter = AI21SemanticTextSplitter()
response = splitter.split_text("Your text")
```

## Tool calls

### Function calling

AI21 models incorporate the Function Calling feature to support custom user functions. The models generate structured
data that includes the function name and proposed arguments. This data empowers applications to call external APIs and
incorporate the resulting information into subsequent model prompts, enriching responses with real-time data and
context. Through function calling, users can access and utilize various services like transportation APIs and financial
data providers to obtain more accurate and relevant answers. Here is an example of how to use function calling
with AI21 models in LangChain:

```python
import os
from getpass import getpass
from langchain_core.messages import HumanMessage, ToolMessage, SystemMessage
from langchain_core.tools import tool
from langchain_ai21.chat_models import ChatAI21
from langchain_core.utils.function_calling import convert_to_openai_tool

os.environ["AI21_API_KEY"] = getpass()

@tool
def get_weather(location: str, date: str) -> str:
    """Provide the weather for the specified location on the given date."""
    if location == "New York" and date == "2024-12-05":
        return "25 celsius"
    elif location == "New York" and date == "2024-12-06":
        return "27 celsius"
    elif location == "London" and date == "2024-12-05":
        return "22 celsius"
    return "32 celsius"

llm = ChatAI21(model="jamba-1.5-mini")

llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])

chat_messages = [SystemMessage(content="You are a helpful assistant. You can use the provided tools "
                                       "to assist with various tasks and provide accurate information")]

human_messages = [
    HumanMessage(content="What is the forecast for the weather in New York on December 5, 2024?"),
    HumanMessage(content="And what about 2024-12-06?"),
    HumanMessage(content="OK, thank you."),
    HumanMessage(content="What is the expected weather in London on December 5, 2024?")]


for human_message in human_messages:
    print(f"User: {human_message.content}")
    chat_messages.append(human_message)
    response = llm_with_tools.invoke(chat_messages)
    chat_messages.append(response)
    if response.tool_calls:
        tool_call = response.tool_calls[0]
        if tool_call["name"] == "get_weather":
            weather = get_weather.invoke(
                {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]})
            chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
            llm_answer = llm_with_tools.invoke(chat_messages)
            print(f"Assistant: {llm_answer.content}")
    else:
        print(f"Assistant: {response.content}")

```
@@ -1,11 +1,20 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Literal, Union, cast, overload
from typing import Any, Dict, Iterator, List, Literal, Optional, Union, cast, overload

from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import ChatCompletionChunk, ChatMessage
from ai21.models.chat import (
    AssistantMessage as AI21AssistantMessage,
)
from ai21.models.chat import ChatCompletionChunk, ChatMessageParam
from ai21.models.chat import ChatMessage as AI21ChatMessage
from ai21.models.chat import SystemMessage as AI21SystemMessage
from ai21.models.chat import ToolCall as AI21ToolCall
from ai21.models.chat import ToolFunction as AI21ToolFunction
from ai21.models.chat import ToolMessage as AI21ToolMessage
from ai21.models.chat import UserMessage as AI21UserMessage
from ai21.stream.stream import Stream as AI21Stream
from langchain_core.messages import (
    AIMessage,

@@ -13,11 +22,15 @@ from langchain_core.messages import (
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    SystemMessage,
    ToolCall,
    ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers.openai_tools import parse_tool_call
from langchain_core.outputs import ChatGenerationChunk

_ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
_ChatMessageTypes = Union[AI21ChatMessage, J2ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
_ROLE_TYPE = Union[str, RoleType]

@@ -40,20 +53,24 @@ class ChatAdapter(ABC):
        self,
        message: BaseMessage,
    ) -> _ChatMessageTypes:
        content = cast(str, message.content)
        role = self._parse_role(message)

        return self._chat_message(role=role, content=content)
        return self._chat_message(role=role, message=message)

    def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
        role = None

        if isinstance(message, SystemMessage):
            return RoleType.SYSTEM

        if isinstance(message, HumanMessage):
            return RoleType.USER

        if isinstance(message, AIMessage):
            return RoleType.ASSISTANT

        if isinstance(message, ToolMessage):
            return RoleType.TOOL

        if isinstance(self, J2ChatAdapter):
            if not role:
                raise ValueError(

@@ -68,7 +85,7 @@ class ChatAdapter(ABC):
    def _chat_message(
        self,
        role: _ROLE_TYPE,
        content: str,
        message: BaseMessage,
    ) -> _ChatMessageTypes:
        pass

@@ -130,9 +147,9 @@ class J2ChatAdapter(ChatAdapter):
    def _chat_message(
        self,
        role: _ROLE_TYPE,
        content: str,
        message: BaseMessage,
    ) -> J2ChatMessage:
        return J2ChatMessage(role=RoleType(role), text=content)
        return J2ChatMessage(role=RoleType(role), text=cast(str, message.content))

    @overload
    def call(

@@ -174,12 +191,65 @@ class JambaChatCompletionsAdapter(ChatAdapter):
            ],
        }

    def _convert_lc_tool_calls_to_ai21_tool_calls(
        self, tool_calls: List[ToolCall]
    ) -> Optional[List[AI21ToolCall]]:
        """
        Convert LangChain ToolCalls to AI21 ToolCalls.
        """
        ai21_tool_calls: List[AI21ToolCall] = []
        for lc_tool_call in tool_calls:
            if "id" not in lc_tool_call or not lc_tool_call["id"]:
                raise ValueError("Tool call ID is missing or empty.")

            ai21_tool_call = AI21ToolCall(
                id=lc_tool_call["id"],
                type="function",
                function=AI21ToolFunction(
                    name=lc_tool_call["name"],
                    arguments=str(lc_tool_call["args"]),
                ),
            )
            ai21_tool_calls.append(ai21_tool_call)

        return ai21_tool_calls

    def _get_content_as_string(self, base_message: BaseMessage) -> str:
        if isinstance(base_message.content, str):
            return base_message.content
        elif isinstance(base_message.content, list):
            return "\n".join(str(item) for item in base_message.content)
        else:
            raise ValueError("Unsupported content type")

    def _chat_message(
        self,
        role: _ROLE_TYPE,
        content: str,
    ) -> ChatMessage:
        return ChatMessage(
        message: BaseMessage,
    ) -> ChatMessageParam:
        content = self._get_content_as_string(message)

        if isinstance(message, AIMessage):
            return AI21AssistantMessage(
                tool_calls=self._convert_lc_tool_calls_to_ai21_tool_calls(
                    message.tool_calls
                ),
                content=content or None,
            )
        if isinstance(message, ToolMessage):
            return AI21ToolMessage(
                tool_call_id=message.tool_call_id,
                content=content,
            )
        if isinstance(message, HumanMessage):
            return AI21UserMessage(
                content=content,
            )
        if isinstance(message, SystemMessage):
            return AI21SystemMessage(
                content=content,
            )
        return AI21ChatMessage(
            role=role.value if isinstance(role, RoleType) else role,
            content=content,
        )

@@ -211,7 +281,18 @@ class JambaChatCompletionsAdapter(ChatAdapter):
        if stream:
            return self._stream_response(response)

        return [AIMessage(choice.message.content) for choice in response.choices]
        ai_messages: List[BaseMessage] = []
        for message in response.choices:
            if message.message.tool_calls:
                tool_calls = [
                    parse_tool_call(tool_call.model_dump(), return_id=True)
                    for tool_call in message.message.tool_calls
                ]
                ai_messages.append(AIMessage("", tool_calls=tool_calls))
            else:
                ai_messages.append(AIMessage(message.message.content))

        return ai_messages

    def _stream_response(
        self,
@@ -1,11 +1,23 @@
import asyncio
from functools import partial
from typing import Any, Dict, Iterator, List, Mapping, Optional
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    LangSmithParams,

@@ -16,6 +28,9 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_ai21.ai21_base import AI21Base
from langchain_ai21.chat.chat_adapter import ChatAdapter

@@ -48,14 +63,14 @@ class ChatAI21(BaseChatModel, AI21Base):
    stop: Optional[List[str]] = None
    """Default stop sequences."""

    max_tokens: int = 16
    max_tokens: int = 512
    """The maximum number of tokens to generate for each response."""

    min_tokens: int = 0
    """The minimum number of tokens to generate for each response.
    _Not supported for all models._"""

    temperature: float = 0.7
    temperature: float = 0.4
    """A value controlling the "creativity" of the model's responses."""

    top_p: float = 1

@@ -246,3 +261,11 @@ class ChatAI21(BaseChatModel, AI21Base):
        )

        return message.content

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        return super().bind(tools=formatted_tools, **kwargs)
libs/partners/ai21/poetry.lock (generated, 75 lines changed)
@@ -1,35 +1,36 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

[[package]]
name = "ai21"
version = "2.7.0"
version = "2.14.1"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
    {file = "ai21-2.7.0-py3-none-any.whl", hash = "sha256:9060aa90f0acc21ce1e3ad90c814762ba0914dd5af073c269868dbcdf5ecd108"},
    {file = "ai21-2.7.0.tar.gz", hash = "sha256:3f86f47af67fa43b086773aa01d89286ec2011dbc1a4a53aaca3a104ac1f958f"},
    {file = "ai21-2.14.1-py3-none-any.whl", hash = "sha256:618c0b5c025123c703645258472330a07ae2de17020438f6a33b29668275995c"},
    {file = "ai21-2.14.1.tar.gz", hash = "sha256:05d9626b82206e0a5be43d17c39d4e0c7b51c2b7634d2ea38a2c698ac3e2fd5b"},
]

[package.dependencies]
ai21-tokenizer = ">=0.11.0,<1.0.0"
dataclasses-json = ">=0.6.3,<0.7.0"
ai21-tokenizer = ">=0.12.0,<1.0.0"
httpx = ">=0.27.0,<0.28.0"
pydantic = ">=1.9.0,<3.0.0"
tenacity = ">=8.3.0,<9.0.0"
typing-extensions = ">=4.9.0,<5.0.0"

[package.extras]
aws = ["boto3 (>=1.28.82,<2.0.0)"]
vertex = ["google-auth (>=2.31.0,<3.0.0)"]

[[package]]
name = "ai21-tokenizer"
version = "0.11.2"
version = "0.12.0"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
    {file = "ai21_tokenizer-0.11.2-py3-none-any.whl", hash = "sha256:a9444ca44ef2bffec7cb9f0c3cfa5501dc973cdde0b740e43e137ce9a2f90eab"},
    {file = "ai21_tokenizer-0.11.2.tar.gz", hash = "sha256:35579bca375f071ae6365456f02bd5c9445f408723f7b87646a2bdaa3f57925e"},
    {file = "ai21_tokenizer-0.12.0-py3-none-any.whl", hash = "sha256:7fd37b9093894b30b0f200e5f44fc8fb8772e2b272ef71b6d73722b4696e63c4"},
    {file = "ai21_tokenizer-0.12.0.tar.gz", hash = "sha256:d2a5b17789d21572504b7693148bf66e692bdb3ab563023dbcbee340bcbd11c6"},
]

[package.dependencies]

@@ -211,21 +212,6 @@
    {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]

[[package]]
name = "dataclasses-json"
version = "0.6.6"
description = "Easily serialize dataclasses to and from JSON."
optional = false
python-versions = "<4.0,>=3.7"
files = [
    {file = "dataclasses_json-0.6.6-py3-none-any.whl", hash = "sha256:e54c5c87497741ad454070ba0ed411523d46beb5da102e221efb873801b0ba85"},
    {file = "dataclasses_json-0.6.6.tar.gz", hash = "sha256:0c09827d26fffda27f1be2fed7a7a01a29c5ddcd2eb6393ad5ebf9d77e9deae8"},
]

[package.dependencies]
marshmallow = ">=3.18.0,<4.0.0"
typing-inspect = ">=0.4.0,<1"

[[package]]
name = "exceptiongroup"
version = "1.2.1"

@@ -448,7 +434,7 @@

[[package]]
name = "langchain-core"
version = "0.2.11"
version = "0.2.33"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"

@@ -465,6 +451,7 @@ pydantic = [
]
PyYAML = ">=5.3"
tenacity = "^8.1.0,!=8.4.0"
typing-extensions = ">=4.7"

[package.source]
type = "directory"

@@ -490,7 +477,7 @@ url = "../../standard-tests"

[[package]]
name = "langchain-text-splitters"
version = "0.2.2"
version = "0.2.3"
description = "LangChain text splitting utilities"
optional = false
python-versions = ">=3.8.1,<4.0"

@@ -523,25 +510,6 @@ pydantic = [
]
requests = ">=2,<3"

[[package]]
name = "marshmallow"
version = "3.21.2"
description = "A lightweight library for converting complex datatypes to and from native Python datatypes."
optional = false
python-versions = ">=3.8"
files = [
    {file = "marshmallow-3.21.2-py3-none-any.whl", hash = "sha256:70b54a6282f4704d12c0a41599682c5c5450e843b9ec406308653b47c59648a1"},
    {file = "marshmallow-3.21.2.tar.gz", hash = "sha256:82408deadd8b33d56338d2182d455db632c6313aa2af61916672146bb32edc56"},
]

[package.dependencies]
packaging = ">=17.0"

[package.extras]
dev = ["marshmallow[tests]", "pre-commit (>=3.5,<4.0)", "tox"]
docs = ["alabaster (==0.7.16)", "autodocsumm (==0.2.12)", "sphinx (==7.3.7)", "sphinx-issues (==4.1.0)", "sphinx-version-warning (==1.1.2)"]
tests = ["pytest", "pytz", "simplejson"]

[[package]]
name = "mypy"
version = "1.10.1"

@@ -1257,21 +1225,6 @@ files = [
    {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"},
]

[[package]]
name = "typing-inspect"
version = "0.9.0"
description = "Runtime inspection utilities for typing module."
optional = false
python-versions = "*"
files = [
    {file = "typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f"},
    {file = "typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78"},
]

[package.dependencies]
mypy-extensions = ">=0.3.0"
typing-extensions = ">=3.7.4"

[[package]]
name = "urllib3"
version = "2.2.1"

@@ -1336,4 +1289,4 @@ watchmedo = ["PyYAML (>=3.10)"]

[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "9dee8f52fd10c8ffe640c33620d7a33daa989ba38f56f7a33a2e98a734457015"
content-hash = "32e1777e151eef2eb3775c4f1707ec4e80241a08fd4beed2d60a36f58d68dfea"
@@ -22,7 +22,7 @@ disallow_untyped_defs = "True"
python = ">=3.8.1,<4.0"
langchain-core = "^0.2.4"
langchain-text-splitters = "^0.2.0"
ai21 = "^2.7.0"
ai21 = "^2.14.1"

[tool.ruff.lint]
select = [ "E", "F", "I",]
@@ -1,12 +1,25 @@
"""Test ChatAI21 chat model."""

import pytest
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.messages import (
    AIMessageChunk,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_ai21.chat_models import ChatAI21
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
from tests.unit_tests.conftest import (
    J2_CHAT_MODEL_NAME,
    JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
    JAMBA_1_5_MINI_CHAT_MODEL_NAME,
    JAMBA_CHAT_MODEL_NAME,
    JAMBA_FAMILY_MODEL_NAMES,
)

rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)

@@ -15,11 +28,15 @@ rate_limiter = InMemoryRateLimiter(requests_per_second=0.5)
    ids=[
        "when_j2_model",
        "when_jamba_model",
        "when_jamba1.5-mini_model",
        "when_jamba1.5-large_model",
    ],
    argnames=["model"],
    argvalues=[
        (J2_CHAT_MODEL_NAME,),
        (JAMBA_CHAT_MODEL_NAME,),
        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
    ],
)
def test_invoke(model: str) -> None:

@@ -36,6 +53,10 @@ def test_invoke(model: str) -> None:
        "when_j2_model_num_results_is_3",
        "when_jamba_model_n_is_1",
        "when_jamba_model_n_is_3",
        "when_jamba1.5_mini_model_n_is_1",
        "when_jamba1.5_mini_model_n_is_3",
        "when_jamba1.5_large_model_n_is_1",
        "when_jamba1.5_large_model_n_is_3",
    ],
    argnames=["model", "num_results"],
    argvalues=[

@@ -43,12 +64,16 @@ def test_invoke(model: str) -> None:
        (J2_CHAT_MODEL_NAME, 3),
        (JAMBA_CHAT_MODEL_NAME, 1),
        (JAMBA_CHAT_MODEL_NAME, 3),
        (JAMBA_1_5_MINI_CHAT_MODEL_NAME, 1),
        (JAMBA_1_5_MINI_CHAT_MODEL_NAME, 3),
        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 1),
        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME, 3),
    ],
)
def test_generation(model: str, num_results: int) -> None:
    """Test generation with multiple models and different result counts."""
    # Determine the configuration key based on the model type
    config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
    config_key = "n" if model in JAMBA_FAMILY_MODEL_NAMES else "num_results"

    # Create the model instance using the appropriate key for the result count
    llm = ChatAI21(model=model, rate_limiter=rate_limiter, **{config_key: num_results})  # type: ignore[arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type, arg-type]

@@ -69,11 +94,15 @@ def test_generation(model: str, num_results: int) -> None:
    ids=[
        "when_j2_model",
        "when_jamba_model",
        "when_jamba1.5_mini_model",
        "when_jamba1.5_large_model",
    ],
    argnames=["model"],
    argvalues=[
        (J2_CHAT_MODEL_NAME,),
        (JAMBA_CHAT_MODEL_NAME,),
        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
    ],
)
async def test_ageneration(model: str) -> None:

@@ -92,7 +121,7 @@ async def test_ageneration(model: str) -> None:


def test__chat_stream() -> None:
    llm = ChatAI21(model="jamba-instruct")  # type: ignore[call-arg]
    llm = ChatAI21(model="jamba-1.5-mini")  # type: ignore[call-arg]
    message = HumanMessage(content="What is the meaning of life?")

    for chunk in llm.stream([message]):

@@ -107,3 +136,53 @@ def test__j2_chat_stream__should_raise_error() -> None:
    with pytest.raises(NotImplementedError):
        for _ in llm.stream([message]):
            pass


@pytest.mark.parametrize(
    ids=[
        "when_jamba1.5_mini_model",
        "when_jamba1.5_large_model",
    ],
    argnames=["model"],
    argvalues=[
        (JAMBA_1_5_MINI_CHAT_MODEL_NAME,),
        (JAMBA_1_5_LARGE_CHAT_MODEL_NAME,),
    ],
)
def test_tool_calls(model: str) -> None:
    @tool
    def get_weather(location: str, date: str) -> str:
        """Provide the weather for the specified location on the given date."""
        if location == "New York" and date == "2024-12-05":
            return "25 celsius"
        return "32 celsius"

    llm = ChatAI21(model=model, temperature=0)  # type: ignore[call-arg]
    llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])

    chat_messages = [
        SystemMessage(
            content="You are a helpful assistant. "
            "You can use the provided tools "
            "to assist with various tasks and provide "
            "accurate information"
        ),
        HumanMessage(
            content="What is the forecast for the weather "
            "in New York on December 5, 2024?"
        ),
    ]

    response = llm_with_tools.invoke(chat_messages)
    chat_messages.append(response)
    assert response.tool_calls is not None  # type: ignore[attr-defined]
    tool_call = response.tool_calls[0]  # type: ignore[attr-defined]
    assert tool_call["name"] == "get_weather"

    weather = get_weather.invoke(  # type: ignore[attr-defined]
        {"location": tool_call["args"]["location"], "date": tool_call["args"]["date"]}
    )
    chat_messages.append(ToolMessage(content=weather, tool_call_id=tool_call["id"]))
    llm_answer = llm_with_tools.invoke(chat_messages)
    content = llm_answer.content.lower()  # type: ignore[union-attr]
    assert "new york" in content and "25" in content and "celsius" in content
@@ -1,7 +1,7 @@
"""Standard LangChain interface tests"""

import time
from typing import Type
from typing import Optional, Type

import pytest
from langchain_core.language_models import BaseChatModel

@@ -28,6 +28,8 @@ class BaseTestAI21(ChatModelIntegrationTests):


class TestAI21J2(BaseTestAI21):
    has_tool_calling = False

    @property
    def chat_model_params(self) -> dict:
        return {

@@ -49,8 +51,23 @@ class TestAI21J2(BaseTestAI21):


class TestAI21Jamba(BaseTestAI21):
    has_tool_calling = False

    @property
    def chat_model_params(self) -> dict:
        return {
            "model": "jamba-instruct-preview",
        }


class TestAI21Jamba1_5(BaseTestAI21):
    @property
    def tool_choice_value(self) -> Optional[str]:
        """Value to use for tool choice when used in tests."""
        return "any"

    @property
    def chat_model_params(self) -> dict:
        return {
            "model": "jamba-1.5-mini",
        }
@@ -3,12 +3,21 @@ from typing import List
import pytest
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import ChatMessage
from ai21.models.chat import (
    AssistantMessage,
    ChatMessage,
    UserMessage,
)
from ai21.models.chat import (
    SystemMessage as AI21SystemMessage,
)
from ai21.models.chat import ToolMessage as AI21ToolMessage
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.messages import (
    ChatMessage as LangChainChatMessage,

@@ -18,6 +27,8 @@ from langchain_ai21.chat.chat_adapter import ChatAdapter

_J2_MODEL_NAME = "j2-ultra"
_JAMBA_MODEL_NAME = "jamba-instruct-preview"
_JAMBA_1_5_MINI_MODEL_NAME = "jamba-1.5-mini"
_JAMBA_1_5_LARGE_MODEL_NAME = "jamba-1.5-large"


@pytest.mark.parametrize(

@@ -42,12 +53,14 @@ _JAMBA_MODEL_NAME = "jamba-instruct-preview"
        (
            _JAMBA_MODEL_NAME,
            HumanMessage(content="Human Message Content"),
            ChatMessage(role=RoleType.USER, content="Human Message Content"),
            UserMessage(role="user", content="Human Message Content"),
        ),
        (
            _JAMBA_MODEL_NAME,
            AIMessage(content="AI Message Content"),
            ChatMessage(role=RoleType.ASSISTANT, content="AI Message Content"),
            AssistantMessage(
                role="assistant", content="AI Message Content", tool_calls=[]
            ),
        ),
    ],
)

@@ -69,7 +82,7 @@ def test_convert_message_to_ai21_message(
    argvalues=[
        (
            _J2_MODEL_NAME,
            SystemMessage(content="System Message Content"),
            AI21SystemMessage(content="System Message Content"),
        ),
        (
            _J2_MODEL_NAME,

@@ -95,6 +108,8 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
        "when_first_message_is_system__should_return_system_j2_model",
        "when_all_messages_are_human_messages__should_return_system_none_jamba_model",
        "when_first_message_is_system__should_return_system_jamba_model",
        "when_tool_calling_message__should_return_tool_jamba_mini_model",
        "when_tool_calling_message__should_return_tool_jamba_large_model",
    ],
    argnames=["model", "messages", "expected_messages"],
    argvalues=[

@@ -142,12 +157,12 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
            ],
            {
                "messages": [
                    ChatMessage(
                        role=RoleType.USER,
                    UserMessage(
                        role="user",
                        content="Human Message Content 1",
                    ),
                    ChatMessage(
                        role=RoleType.USER,
                    UserMessage(
                        role="user",
                        content="Human Message Content 2",
                    ),
                ]

@@ -161,8 +176,46 @@ def test_convert_message_to_ai21_message__when_invalid_role__should_raise_except
            ],
            {
                "messages": [
                    ChatMessage(role="system", content="System Message Content 1"),
                    ChatMessage(role="user", content="Human Message Content 1"),
                    AI21SystemMessage(
                        role="system", content="System Message Content 1"
                    ),
                    UserMessage(role="user", content="Human Message Content 1"),
                ],
            },
        ),
        (
            _JAMBA_1_5_MINI_MODEL_NAME,
            [
                ToolMessage(
                    content="42",
                    tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
                )
            ],
            {
                "messages": [
                    AI21ToolMessage(
                        role="tool",
                        tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
                        content="42",
                    ),
                ],
            },
        ),
        (
            _JAMBA_1_5_LARGE_MODEL_NAME,
            [
                ToolMessage(
                    content="42",
                    tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
                )
            ],
            {
                "messages": [
                    AI21ToolMessage(
                        role="tool",
                        tool_call_id="call_Jja7J89XsjrOLA5r!MEOW!SL",
                        content="42",
                    ),
                ],
            },
        ),
@@ -23,8 +23,16 @@ from pytest_mock import MockerFixture

J2_CHAT_MODEL_NAME = "j2-ultra"
JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
JAMBA_1_5_MINI_CHAT_MODEL_NAME = "jamba-1.5-mini"
JAMBA_1_5_LARGE_CHAT_MODEL_NAME = "jamba-1.5-large"
DUMMY_API_KEY = "test_api_key"

JAMBA_FAMILY_MODEL_NAMES = [
    JAMBA_CHAT_MODEL_NAME,
    JAMBA_1_5_MINI_CHAT_MODEL_NAME,
    JAMBA_1_5_LARGE_CHAT_MODEL_NAME,
]

BASIC_EXAMPLE_LLM_PARAMETERS = {
    "num_results": 3,
    "max_tokens": 20,

@@ -32,9 +40,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
    "count_penalty": Penalty(
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
    "count_penalty": Penalty(  # type: ignore[call-arg]
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,

@@ -48,9 +56,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS = {
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
    "count_penalty": Penalty(
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
    "count_penalty": Penalty(  # type: ignore[call-arg]
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,

@@ -59,7 +67,7 @@
}

SEGMENTS = [
    Segment(
    Segment(  # type: ignore[call-arg]
        segment_type="normal_text",
        segment_text=(
            "The original full name of the franchise is Pocket Monsters "

@@ -70,7 +78,7 @@
            "in pronunciation."
        ),
    ),
    Segment(
    Segment(  # type: ignore[call-arg]
        segment_type="normal_text",
        segment_text=(
            "Pokémon refers to both the franchise itself and the creatures "

@@ -92,9 +100,9 @@ BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
    "count_penalty": Penalty(
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),  # type: ignore[call-arg]
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),  # type: ignore[call-arg]
    "count_penalty": Penalty(  # type: ignore[call-arg]
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,

@@ -108,9 +116,9 @@ BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
    "temperature": 0.5,
    "top_p": 0.5,
    "top_k_return": 0,
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
    "count_penalty": Penalty(
    "frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),  # type: ignore[call-arg]
    "presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),  # type: ignore[call-arg]
    "count_penalty": Penalty(  # type: ignore[call-arg]
        scale=0.2,
        apply_to_punctuation=True,
        apply_to_emojis=True,

@@ -124,7 +132,7 @@ def mocked_completion_response(mocker: MockerFixture) -> Mock:
    mocked_response = mocker.MagicMock(spec=CompletionsResponse)
    mocked_response.prompt = "this is a test prompt"
    mocked_response.completions = [
        Completion(
        Completion(  # type: ignore[call-arg]
            data=CompletionData(text="test", tokens=[]),
            finish_reason=CompletionFinishReason(reason=None, length=None),
        )

@@ -152,7 +160,7 @@ def mock_client_with_chat(mocker: MockerFixture) -> Mock:
    mock_client = mocker.MagicMock(spec=AI21Client)
    mock_client.chat = mocker.MagicMock()

    output = ChatOutput(
    output = ChatOutput(  # type: ignore[call-arg]
        text="Hello Pickle Rick!",
        role=RoleType.ASSISTANT,
        finish_reason=FinishReason(reason="testing"),

@@ -178,7 +186,7 @@ def temporarily_unset_api_key() -> Generator:
def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
    mock_client = mocker.MagicMock(spec=AI21Client)
    mock_client.answer = mocker.MagicMock()
    mock_client.answer.create.return_value = AnswerResponse(
    mock_client.answer.create.return_value = AnswerResponse(  # type: ignore[call-arg]
        id="some_id",
        answer="some answer",
        answer_in_context=False,
@@ -45,9 +45,9 @@ def test_initialization__when_custom_parameters_in_init() -> None:
    temperature = 0.1
    top_p = 0.1
    top_k_return = 0
    frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)
    presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)
    count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)
    frequency_penalty = Penalty(scale=0.2, apply_to_numbers=True)  # type: ignore[call-arg]
    presence_penalty = Penalty(scale=0.2, apply_to_stopwords=True)  # type: ignore[call-arg]
    count_penalty = Penalty(scale=0.2, apply_to_punctuation=True, apply_to_emojis=True)  # type: ignore[call-arg]

    llm = ChatAI21(  # type: ignore[call-arg]
        api_key=DUMMY_API_KEY,  # type: ignore[arg-type]
@@ -17,9 +17,9 @@ _EXAMPLE_EMBEDDING_2 = [7.0, 8.0, 9.0]

_EXAMPLE_EMBEDDING_RESPONSE = EmbedResponse(
    results=[
        EmbedResult(_EXAMPLE_EMBEDDING_0),
        EmbedResult(_EXAMPLE_EMBEDDING_1),
        EmbedResult(_EXAMPLE_EMBEDDING_2),
        EmbedResult(embedding=_EXAMPLE_EMBEDDING_0),
        EmbedResult(embedding=_EXAMPLE_EMBEDDING_1),
        EmbedResult(embedding=_EXAMPLE_EMBEDDING_2),
    ],
    id="test_id",
)
@@ -49,9 +49,9 @@ def test_initialization__when_custom_parameters_to_init() -> None:
        top_p=0.5,
        top_k_return=0,
        stop_sequences=["\n"],
        frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),
        presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),
        count_penalty=Penalty(
        frequency_penalty=Penalty(scale=0.2, apply_to_numbers=True),  # type: ignore[call-arg]
        presence_penalty=Penalty(scale=0.2, apply_to_stopwords=True),  # type: ignore[call-arg]
        count_penalty=Penalty(  # type: ignore[call-arg]
            scale=0.2, apply_to_punctuation=True, apply_to_emojis=True
        ),
        custom_model="test_model",