partners: AI21 Labs Jamba Support (#20815)

Description: Added support for AI21 new model - Jamba
GitHub organization: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
pull/21151/head
Asaf Joseph Gardin 2 months ago committed by GitHub
parent 7a39fe60da
commit 642975dd9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,130 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Union, cast
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import ChatMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
_ChatMessageTypes = Union[ChatMessage, J2ChatMessage]
_SYSTEM_ERR_MESSAGE = "System message must be at beginning of message list."
_ROLE_TYPE = Union[str, RoleType]
class ChatAdapter(ABC):
    """Common interface for the different chat model families available in AI21.

    Responsibilities:
    * convert LangChain messages into the AI21 request payload,
    * call the appropriate AI21 client API with that payload.
    """

    @abstractmethod
    def convert_messages(
        self,
        messages: List[BaseMessage],
    ) -> Dict[str, Any]:
        """Convert LangChain messages to the request dict for the AI21 API."""

    def _convert_message_to_ai21_message(
        self,
        message: BaseMessage,
    ) -> _ChatMessageTypes:
        """Map a single LangChain message to this adapter's AI21 message type."""
        content = cast(str, message.content)
        role = self._parse_role(message)
        return self._chat_message(role=role, content=content)

    def _parse_role(self, message: BaseMessage) -> _ROLE_TYPE:
        """Resolve the AI21 role for ``message``.

        J2 models only accept user/assistant roles, so any other message type
        is an error there; for newer models we pass the raw LangChain
        ``message.type`` through and rely on the server to resolve it.
        """
        if isinstance(message, HumanMessage):
            return RoleType.USER
        if isinstance(message, AIMessage):
            return RoleType.ASSISTANT

        # NOTE: the original code set ``role = None`` and re-checked
        # ``if not role`` here — role was provably always None at this point,
        # so the check was dead; raise unconditionally for J2.
        if isinstance(self, J2ChatAdapter):
            raise ValueError(
                f"Could not resolve role type from message {message}. "
                f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
            )

        # if it gets here, we rely on the server to handle the role type
        return message.type

    @abstractmethod
    def _chat_message(
        self,
        role: _ROLE_TYPE,
        content: str,
    ) -> _ChatMessageTypes:
        """Build the concrete AI21 message object for this model family."""

    @abstractmethod
    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
        """Invoke the AI21 client and convert the response to messages."""

    def _get_system_message_from_message(self, message: BaseMessage) -> str:
        """Extract the system prompt text; system content must be a str."""
        if not isinstance(message.content, str):
            raise ValueError(
                f"System Message must be of type str. Got {type(message.content)}"
            )
        return message.content
class J2ChatAdapter(ChatAdapter):
    """Adapter for AI21's Jurassic-2 (J2) chat API."""

    def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        """Split messages into one system prompt plus J2 chat messages."""
        system_message = ""
        converted_messages = []  # type: ignore
        for index, current in enumerate(messages):
            if current.type != "system":
                converted_messages.append(
                    self._convert_message_to_ai21_message(current)
                )
                continue
            # A system message is only legal as the very first entry.
            if index != 0:
                raise ValueError(_SYSTEM_ERR_MESSAGE)
            system_message = self._get_system_message_from_message(current)
        return {"system": system_message, "messages": converted_messages}

    def _chat_message(
        self,
        role: _ROLE_TYPE,
        content: str,
    ) -> J2ChatMessage:
        """Wrap role/content in the J2 message type (coerces role to RoleType)."""
        return J2ChatMessage(role=RoleType(role), text=content)

    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
        """Call the J2 chat endpoint and map each output to an AIMessage."""
        response = client.chat.create(**params)
        results = []
        for output in response.outputs:
            results.append(AIMessage(output.text))
        return results
class JambaChatCompletionsAdapter(ChatAdapter):
    """Adapter for AI21's Jamba chat-completions API."""

    def convert_messages(self, messages: List[BaseMessage]) -> Dict[str, Any]:
        """Jamba accepts the full message list directly — no system split."""
        converted = [self._convert_message_to_ai21_message(m) for m in messages]
        return {"messages": converted}

    def _chat_message(
        self,
        role: _ROLE_TYPE,
        content: str,
    ) -> ChatMessage:
        """Build a Jamba ChatMessage; RoleType enums are flattened to str."""
        if isinstance(role, RoleType):
            role = role.value
        return ChatMessage(role=role, content=content)

    def call(self, client: Any, **params: Any) -> List[BaseMessage]:
        """Call the chat-completions endpoint; one AIMessage per choice."""
        response = client.chat.completions.create(**params)
        return [AIMessage(c.message.content) for c in response.choices]

@ -0,0 +1,15 @@
from langchain_ai21.chat.chat_adapter import (
ChatAdapter,
J2ChatAdapter,
JambaChatCompletionsAdapter,
)
def create_chat_adapter(model: str) -> ChatAdapter:
    """Pick the chat adapter matching the given AI21 model name.

    Raises:
        ValueError: if ``model`` is neither a J2 nor a Jamba model.
    """
    is_j2 = "j2" in model
    is_jamba = "jamba" in model
    if not (is_j2 or is_jamba):
        raise ValueError(f"Model {model} not supported.")
    # J2 takes precedence, mirroring the original if-chain ordering.
    return J2ChatAdapter() if is_j2 else JambaChatCompletionsAdapter()

@ -1,79 +1,21 @@
import asyncio
from functools import partial
from typing import Any, List, Mapping, Optional, Tuple, cast
from typing import Any, Dict, List, Mapping, Optional
from ai21.models import ChatMessage, RoleType
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import root_validator
from langchain_ai21.ai21_base import AI21Base
def _get_system_message_from_message(message: BaseMessage) -> str:
if not isinstance(message.content, str):
raise ValueError(
f"System Message must be of type str. Got {type(message.content)}"
)
return message.content
def _convert_messages_to_ai21_messages(
    messages: List[BaseMessage],
) -> Tuple[Optional[str], List[ChatMessage]]:
    """Split LangChain messages into (system prompt, AI21 chat messages).

    Raises:
        ValueError: if a system message appears anywhere but first.
    """
    system_message: Optional[str] = None
    converted: List[ChatMessage] = []
    for position, msg in enumerate(messages):
        if msg.type != "system":
            converted.append(_convert_message_to_ai21_message(msg))
            continue
        if position != 0:
            raise ValueError("System message must be at beginning of message list.")
        system_message = _get_system_message_from_message(msg)
    return system_message, converted
def _convert_message_to_ai21_message(
    message: BaseMessage,
) -> ChatMessage:
    """Map a LangChain Human/AI message to an AI21 ChatMessage.

    Raises:
        ValueError: for any message type other than Human or AI.
    """
    if isinstance(message, HumanMessage):
        role = RoleType.USER
    elif isinstance(message, AIMessage):
        role = RoleType.ASSISTANT
    else:
        raise ValueError(
            f"Could not resolve role type from message {message}. "
            f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
        )
    return ChatMessage(role=role, text=cast(str, message.content))
def _pop_system_messages(messages: List[BaseMessage]) -> List[SystemMessage]:
    """Remove and return all SystemMessages from ``messages`` (mutates in place).

    Bug fixed: the previous version collected indexes up front and then called
    ``messages.pop(i)`` in ascending order — each pop shifts the remaining
    elements left, so with more than one system message the later pops removed
    the wrong element or raised IndexError. Popping in reverse index order
    leaves earlier indexes valid; the result is re-reversed so the returned
    system messages keep their original order.
    """
    system_message_indexes = [
        i for i, message in enumerate(messages) if isinstance(message, SystemMessage)
    ]
    popped = [
        cast(SystemMessage, messages.pop(i))
        for i in reversed(system_message_indexes)
    ]
    popped.reverse()
    return popped
from langchain_ai21.chat.chat_adapter import ChatAdapter
from langchain_ai21.chat.chat_factory import create_chat_adapter
class ChatAI21(BaseChatModel, AI21Base):
@ -119,6 +61,20 @@ class ChatAI21(BaseChatModel, AI21Base):
"""A penalty applied to tokens based on their frequency
in the generated responses."""
n: int = 1
"""Number of chat completions to generate for each prompt."""
_chat_adapter: ChatAdapter
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
    """Run base environment validation, then attach the model-specific
    chat adapter so request building/calls can delegate to it."""
    values = super().validate_environment(values)
    model = values.get("model")
    # NOTE(review): assumes "model" is present after base validation —
    # create_chat_adapter raises if it resolves to an unsupported name.
    values["_chat_adapter"] = create_chat_adapter(model)  # type: ignore
    return values
class Config:
"""Configuration for this pydantic object."""
@ -139,6 +95,7 @@ class ChatAI21(BaseChatModel, AI21Base):
"temperature": self.temperature,
"top_p": self.top_p,
"top_k_return": self.top_k_return,
"n": self.n,
}
if self.count_penalty is not None:
@ -159,7 +116,7 @@ class ChatAI21(BaseChatModel, AI21Base):
**kwargs: Any,
) -> Mapping[str, Any]:
params = {}
system, ai21_messages = _convert_messages_to_ai21_messages(messages)
converted_messages = self._chat_adapter.convert_messages(messages)
if stop is not None:
if "stop" in kwargs:
@ -167,8 +124,7 @@ class ChatAI21(BaseChatModel, AI21Base):
params["stop_sequences"] = stop
return {
"system": system or "",
"messages": ai21_messages,
**converted_messages,
**self._default_params,
**params,
**kwargs,
@ -182,12 +138,10 @@ class ChatAI21(BaseChatModel, AI21Base):
**kwargs: Any,
) -> ChatResult:
params = self._build_params_for_request(messages=messages, stop=stop, **kwargs)
messages = self._chat_adapter.call(self.client, **params)
generations = [ChatGeneration(message=message) for message in messages]
response = self.client.chat.create(**params)
outputs = response.outputs
message = AIMessage(content=outputs[0].text)
return ChatResult(generations=[ChatGeneration(message=message)])
return ChatResult(generations=generations)
async def _agenerate(
self,
@ -199,3 +153,11 @@ class ChatAI21(BaseChatModel, AI21Base):
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, **kwargs), messages, stop, run_manager
)
def _get_system_message_from_message(self, message: BaseMessage) -> str:
    """Return the system prompt text; system content must be a plain str."""
    if not isinstance(message.content, str):
        raise ValueError(
            f"System Message must be of type str. Got {type(message.content)}"
        )
    return message.content

@ -2,17 +2,17 @@
[[package]]
name = "ai21"
version = "2.1.3"
version = "2.2.1"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
{file = "ai21-2.1.3-py3-none-any.whl", hash = "sha256:384f3b5769edfb7124e9288bc83daeeeee35f6f831be2f77425c061774d5ade3"},
{file = "ai21-2.1.3.tar.gz", hash = "sha256:80014f54818453e87ced5c3d180e22a6cfb25911231ec834cae55f8769106afe"},
{file = "ai21-2.2.1-py3-none-any.whl", hash = "sha256:c2727656b07a1284fc8f1f2fb1b065059f6fa5fcab8af2f45b0767086a044771"},
{file = "ai21-2.2.1.tar.gz", hash = "sha256:69e659071e9bbadae7349465bc388312bcff30e22fb7d7a9de4319e2925acfff"},
]
[package.dependencies]
ai21-tokenizer = ">=0.3.9,<0.4.0"
ai21-tokenizer = ">=0.9.0,<0.10.0"
dataclasses-json = ">=0.6.3,<0.7.0"
requests = ">=2.31.0,<3.0.0"
typing-extensions = ">=4.9.0,<5.0.0"
@ -22,17 +22,18 @@ aws = ["boto3 (>=1.28.82,<2.0.0)"]
[[package]]
name = "ai21-tokenizer"
version = "0.3.11"
version = "0.9.0"
description = ""
optional = false
python-versions = ">=3.7,<4.0"
python-versions = "<4.0,>=3.7"
files = [
{file = "ai21_tokenizer-0.3.11-py3-none-any.whl", hash = "sha256:80d332c51cab3fa88f0fea7493240a6a5bc38fd24a3d0806d28731d8fc97691f"},
{file = "ai21_tokenizer-0.3.11.tar.gz", hash = "sha256:ec11ce4e46d24f71f1c2756ad0de34e0adfd51b5bcd81b544aea13d6935ec905"},
{file = "ai21_tokenizer-0.9.0-py3-none-any.whl", hash = "sha256:3e4927f0ef98923f53710405fb476b29102bc7c677c65243043c5671cb642d9e"},
{file = "ai21_tokenizer-0.9.0.tar.gz", hash = "sha256:724d368eb74564950edfcd6590da7c3b702c4a8f59f60d36dbec5046e6856579"},
]
[package.dependencies]
sentencepiece = ">=0.1.96,<0.2.0"
sentencepiece = ">=0.2.0,<0.3.0"
tokenizers = ">=0.15.2,<0.16.0"
[[package]]
name = "annotated-types"
@ -215,6 +216,22 @@ files = [
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "filelock"
version = "3.13.4"
description = "A platform independent file lock."
optional = false
python-versions = ">=3.8"
files = [
{file = "filelock-3.13.4-py3-none-any.whl", hash = "sha256:404e5e9253aa60ad457cae1be07c0f0ca90a63931200a47d9b6a6af84fd7b45f"},
{file = "filelock-3.13.4.tar.gz", hash = "sha256:d13f466618bfde72bd2c18255e269f72542c6e70e7bac83a0232d6b1cc5c8cf4"},
]
[package.extras]
docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"]
testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"]
typing = ["typing-extensions (>=4.8)"]
[[package]]
name = "freezegun"
version = "1.4.0"
@ -229,6 +246,75 @@ files = [
[package.dependencies]
python-dateutil = ">=2.7"
[[package]]
name = "fsspec"
version = "2024.3.1"
description = "File-system specification"
optional = false
python-versions = ">=3.8"
files = [
{file = "fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512"},
{file = "fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9"},
]
[package.extras]
abfs = ["adlfs"]
adl = ["adlfs"]
arrow = ["pyarrow (>=1)"]
dask = ["dask", "distributed"]
devel = ["pytest", "pytest-cov"]
dropbox = ["dropbox", "dropboxdrivefs", "requests"]
full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
fuse = ["fusepy"]
gcs = ["gcsfs"]
git = ["pygit2"]
github = ["requests"]
gs = ["gcsfs"]
gui = ["panel"]
hdfs = ["pyarrow (>=1)"]
http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"]
libarchive = ["libarchive-c"]
oci = ["ocifs"]
s3 = ["s3fs"]
sftp = ["paramiko"]
smb = ["smbprotocol"]
ssh = ["paramiko"]
tqdm = ["tqdm"]
[[package]]
name = "huggingface-hub"
version = "0.22.2"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.22.2-py3-none-any.whl", hash = "sha256:3429e25f38ccb834d310804a3b711e7e4953db5a9e420cc147a5e194ca90fd17"},
{file = "huggingface_hub-0.22.2.tar.gz", hash = "sha256:32e9a9a6843c92f253ff9ca16b9985def4d80a93fb357af5353f770ef74a81be"},
]
[package.dependencies]
filelock = "*"
fsspec = ">=2023.5.0"
packaging = ">=20.9"
pyyaml = ">=5.1"
requests = "*"
tqdm = ">=4.42.1"
typing-extensions = ">=3.7.4.3"
[package.extras]
all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
cli = ["InquirerPy (==0.3.4)"]
dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "mypy (==1.5.1)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.3.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"]
fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
hf-transfer = ["hf-transfer (>=0.1.4)"]
inference = ["aiohttp", "minijinja (>=1.0)"]
quality = ["mypy (==1.5.1)", "ruff (>=0.3.0)"]
tensorflow = ["graphviz", "pydot", "tensorflow"]
tensorflow-testing = ["keras (<3.0)", "tensorflow"]
testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "gradio", "jedi", "minijinja (>=1.0)", "numpy", "pytest", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"]
torch = ["safetensors", "torch"]
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"]
[[package]]
name = "idna"
version = "3.6"
@ -278,7 +364,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.42"
version = "0.1.45"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -737,7 +823,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -821,56 +906,64 @@ files = [
[[package]]
name = "sentencepiece"
version = "0.1.99"
version = "0.2.0"
description = "SentencePiece python wrapper"
optional = false
python-versions = "*"
files = [
{file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0eb528e70571b7c02723e5804322469b82fe7ea418c96051d0286c0fa028db73"},
{file = "sentencepiece-0.1.99-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:77d7fafb2c4e4659cbdf303929503f37a26eabc4ff31d3a79bf1c5a1b338caa7"},
{file = "sentencepiece-0.1.99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be9cf5b9e404c245aeb3d3723c737ba7a8f5d4ba262ef233a431fa6c45f732a0"},
{file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baed1a26464998f9710d20e52607c29ffd4293e7c71c6a1f83f51ad0911ec12c"},
{file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9832f08bb372d4c8b567612f8eab9e36e268dff645f1c28f9f8e851be705f6d1"},
{file = "sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:019e7535108e309dae2b253a75834fc3128240aa87c00eb80732078cdc182588"},
{file = "sentencepiece-0.1.99-cp310-cp310-win32.whl", hash = "sha256:fa16a830416bb823fa2a52cbdd474d1f7f3bba527fd2304fb4b140dad31bb9bc"},
{file = "sentencepiece-0.1.99-cp310-cp310-win_amd64.whl", hash = "sha256:14b0eccb7b641d4591c3e12ae44cab537d68352e4d3b6424944f0c447d2348d5"},
{file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6d3c56f24183a1e8bd61043ff2c58dfecdc68a5dd8955dc13bab83afd5f76b81"},
{file = "sentencepiece-0.1.99-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed6ea1819fd612c989999e44a51bf556d0ef6abfb553080b9be3d347e18bcfb7"},
{file = "sentencepiece-0.1.99-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a2a0260cd1fb7bd8b4d4f39dc2444a8d5fd4e0a0c4d5c899810ef1abf99b2d45"},
{file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a1abff4d1ff81c77cac3cc6fefa34fa4b8b371e5ee51cb7e8d1ebc996d05983"},
{file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:004e6a621d4bc88978eecb6ea7959264239a17b70f2cbc348033d8195c9808ec"},
{file = "sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db361e03342c41680afae5807590bc88aa0e17cfd1a42696a160e4005fcda03b"},
{file = "sentencepiece-0.1.99-cp311-cp311-win32.whl", hash = "sha256:2d95e19168875b70df62916eb55428a0cbcb834ac51d5a7e664eda74def9e1e0"},
{file = "sentencepiece-0.1.99-cp311-cp311-win_amd64.whl", hash = "sha256:f90d73a6f81248a909f55d8e6ef56fec32d559e1e9af045f0b0322637cb8e5c7"},
{file = "sentencepiece-0.1.99-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:62e24c81e74bd87a6e0d63c51beb6527e4c0add67e1a17bac18bcd2076afcfeb"},
{file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:57efcc2d51caff20d9573567d9fd3f854d9efe613ed58a439c78c9f93101384a"},
{file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6a904c46197993bd1e95b93a6e373dca2f170379d64441041e2e628ad4afb16f"},
{file = "sentencepiece-0.1.99-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d89adf59854741c0d465f0e1525b388c0d174f611cc04af54153c5c4f36088c4"},
{file = "sentencepiece-0.1.99-cp36-cp36m-win32.whl", hash = "sha256:47c378146928690d1bc106fdf0da768cebd03b65dd8405aa3dd88f9c81e35dba"},
{file = "sentencepiece-0.1.99-cp36-cp36m-win_amd64.whl", hash = "sha256:9ba142e7a90dd6d823c44f9870abdad45e6c63958eb60fe44cca6828d3b69da2"},
{file = "sentencepiece-0.1.99-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b7b1a9ae4d7c6f1f867e63370cca25cc17b6f4886729595b885ee07a58d3cec3"},
{file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0f644c9d4d35c096a538507b2163e6191512460035bf51358794a78515b74f7"},
{file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8843d23a0f686d85e569bd6dcd0dd0e0cbc03731e63497ca6d5bacd18df8b85"},
{file = "sentencepiece-0.1.99-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e6f690a1caebb4867a2e367afa1918ad35be257ecdb3455d2bbd787936f155"},
{file = "sentencepiece-0.1.99-cp37-cp37m-win32.whl", hash = "sha256:8a321866c2f85da7beac74a824b4ad6ddc2a4c9bccd9382529506d48f744a12c"},
{file = "sentencepiece-0.1.99-cp37-cp37m-win_amd64.whl", hash = "sha256:c42f753bcfb7661c122a15b20be7f684b61fc8592c89c870adf52382ea72262d"},
{file = "sentencepiece-0.1.99-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:85b476406da69c70586f0bb682fcca4c9b40e5059814f2db92303ea4585c650c"},
{file = "sentencepiece-0.1.99-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cfbcfe13c69d3f87b7fcd5da168df7290a6d006329be71f90ba4f56bc77f8561"},
{file = "sentencepiece-0.1.99-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:445b0ec381af1cd4eef95243e7180c63d9c384443c16c4c47a28196bd1cda937"},
{file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6890ea0f2b4703f62d0bf27932e35808b1f679bdb05c7eeb3812b935ba02001"},
{file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb71af492b0eefbf9f2501bec97bcd043b6812ab000d119eaf4bd33f9e283d03"},
{file = "sentencepiece-0.1.99-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27b866b5bd3ddd54166bbcbf5c8d7dd2e0b397fac8537991c7f544220b1f67bc"},
{file = "sentencepiece-0.1.99-cp38-cp38-win32.whl", hash = "sha256:b133e8a499eac49c581c3c76e9bdd08c338cc1939e441fee6f92c0ccb5f1f8be"},
{file = "sentencepiece-0.1.99-cp38-cp38-win_amd64.whl", hash = "sha256:0eaf3591dd0690a87f44f4df129cf8d05d8a4029b5b6709b489b8e27f9a9bcff"},
{file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:38efeda9bbfb55052d482a009c6a37e52f42ebffcea9d3a98a61de7aee356a28"},
{file = "sentencepiece-0.1.99-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c030b081dc1e1bcc9fadc314b19b740715d3d566ad73a482da20d7d46fd444c"},
{file = "sentencepiece-0.1.99-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:84dbe53e02e4f8a2e45d2ac3e430d5c83182142658e25edd76539b7648928727"},
{file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b0f55d0a0ee1719b4b04221fe0c9f0c3461dc3dabd77a035fa2f4788eb3ef9a"},
{file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e800f206cd235dc27dc749299e05853a4e4332e8d3dfd81bf13d0e5b9007d9"},
{file = "sentencepiece-0.1.99-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ae1c40cda8f9d5b0423cfa98542735c0235e7597d79caf318855cdf971b2280"},
{file = "sentencepiece-0.1.99-cp39-cp39-win32.whl", hash = "sha256:c84ce33af12ca222d14a1cdd37bd76a69401e32bc68fe61c67ef6b59402f4ab8"},
{file = "sentencepiece-0.1.99-cp39-cp39-win_amd64.whl", hash = "sha256:350e5c74d739973f1c9643edb80f7cc904dc948578bcb1d43c6f2b173e5d18dd"},
{file = "sentencepiece-0.1.99.tar.gz", hash = "sha256:189c48f5cb2949288f97ccdb97f0473098d9c3dcf5a3d99d4eabe719ec27297f"},
{file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"},
{file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"},
{file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"},
{file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"},
{file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"},
{file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"},
{file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"},
{file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"},
{file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"},
{file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"},
{file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"},
{file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"},
{file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"},
{file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"},
{file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"},
{file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"},
{file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"},
{file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"},
{file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"},
{file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"},
{file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"},
{file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"},
{file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"},
{file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"},
{file = "sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"},
{file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"},
{file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"},
{file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"},
{file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"},
{file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"},
{file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"},
{file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"},
{file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"},
{file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"},
{file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"},
{file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"},
{file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"},
{file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"},
{file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"},
{file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"},
{file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"},
{file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"},
{file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"},
{file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"},
{file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"},
{file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"},
{file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"},
{file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"},
{file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"},
{file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"},
{file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"},
{file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"},
{file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"},
]
[[package]]
@ -912,6 +1005,133 @@ files = [
[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]
[[package]]
name = "tokenizers"
version = "0.15.2"
description = ""
optional = false
python-versions = ">=3.7"
files = [
{file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"},
{file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"},
{file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a9b9b070fdad06e347563b88c278995735292ded1132f8657084989a4c84a6d5"},
{file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea621a7eef4b70e1f7a4e84dd989ae3f0eeb50fc8690254eacc08acb623e82f1"},
{file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cf7fd9a5141634fa3aa8d6b7be362e6ae1b4cda60da81388fa533e0b552c98fd"},
{file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:44f2a832cd0825295f7179eaf173381dc45230f9227ec4b44378322d900447c9"},
{file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8b9ec69247a23747669ec4b0ca10f8e3dfb3545d550258129bd62291aabe8605"},
{file = "tokenizers-0.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b6a4c78da863ff26dbd5ad9a8ecc33d8a8d97b535172601cf00aee9d7ce9ce"},
{file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5ab2a4d21dcf76af60e05af8063138849eb1d6553a0d059f6534357bce8ba364"},
{file = "tokenizers-0.15.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a47acfac7e511f6bbfcf2d3fb8c26979c780a91e06fb5b9a43831b2c0153d024"},
{file = "tokenizers-0.15.2-cp310-none-win32.whl", hash = "sha256:064ff87bb6acdbd693666de9a4b692add41308a2c0ec0770d6385737117215f2"},
{file = "tokenizers-0.15.2-cp310-none-win_amd64.whl", hash = "sha256:3b919afe4df7eb6ac7cafd2bd14fb507d3f408db7a68c43117f579c984a73843"},
{file = "tokenizers-0.15.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:89cd1cb93e4b12ff39bb2d626ad77e35209de9309a71e4d3d4672667b4b256e7"},
{file = "tokenizers-0.15.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cfed5c64e5be23d7ee0f0e98081a25c2a46b0b77ce99a4f0605b1ec43dd481fa"},
{file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a907d76dcfda37023ba203ab4ceeb21bc5683436ebefbd895a0841fd52f6f6f2"},
{file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:20ea60479de6fc7b8ae756b4b097572372d7e4032e2521c1bbf3d90c90a99ff0"},
{file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:48e2b9335be2bc0171df9281385c2ed06a15f5cf121c44094338306ab7b33f2c"},
{file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:112a1dd436d2cc06e6ffdc0b06d55ac019a35a63afd26475205cb4b1bf0bfbff"},
{file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4620cca5c2817177ee8706f860364cc3a8845bc1e291aaf661fb899e5d1c45b0"},
{file = "tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccd73a82751c523b3fc31ff8194702e4af4db21dc20e55b30ecc2079c5d43cb7"},
{file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:107089f135b4ae7817affe6264f8c7a5c5b4fd9a90f9439ed495f54fcea56fb4"},
{file = "tokenizers-0.15.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0ff110ecc57b7aa4a594396525a3451ad70988e517237fe91c540997c4e50e29"},
{file = "tokenizers-0.15.2-cp311-none-win32.whl", hash = "sha256:6d76f00f5c32da36c61f41c58346a4fa7f0a61be02f4301fd30ad59834977cc3"},
{file = "tokenizers-0.15.2-cp311-none-win_amd64.whl", hash = "sha256:cc90102ed17271cf0a1262babe5939e0134b3890345d11a19c3145184b706055"},
{file = "tokenizers-0.15.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f86593c18d2e6248e72fb91c77d413a815153b8ea4e31f7cd443bdf28e467670"},
{file = "tokenizers-0.15.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0774bccc6608eca23eb9d620196687c8b2360624619623cf4ba9dc9bd53e8b51"},
{file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d0222c5b7c9b26c0b4822a82f6a7011de0a9d3060e1da176f66274b70f846b98"},
{file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3835738be1de66624fff2f4f6f6684775da4e9c00bde053be7564cbf3545cc66"},
{file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0143e7d9dcd811855c1ce1ab9bf5d96d29bf5e528fd6c7824d0465741e8c10fd"},
{file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db35825f6d54215f6b6009a7ff3eedee0848c99a6271c870d2826fbbedf31a38"},
{file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3f5e64b0389a2be47091d8cc53c87859783b837ea1a06edd9d8e04004df55a5c"},
{file = "tokenizers-0.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e0480c452217edd35eca56fafe2029fb4d368b7c0475f8dfa3c5c9c400a7456"},
{file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a33ab881c8fe70474980577e033d0bc9a27b7ab8272896e500708b212995d834"},
{file = "tokenizers-0.15.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a308a607ca9de2c64c1b9ba79ec9a403969715a1b8ba5f998a676826f1a7039d"},
{file = "tokenizers-0.15.2-cp312-none-win32.whl", hash = "sha256:b8fcfa81bcb9447df582c5bc96a031e6df4da2a774b8080d4f02c0c16b42be0b"},
{file = "tokenizers-0.15.2-cp312-none-win_amd64.whl", hash = "sha256:38d7ab43c6825abfc0b661d95f39c7f8af2449364f01d331f3b51c94dcff7221"},
{file = "tokenizers-0.15.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:38bfb0204ff3246ca4d5e726e8cc8403bfc931090151e6eede54d0e0cf162ef0"},
{file = "tokenizers-0.15.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c861d35e8286a53e06e9e28d030b5a05bcbf5ac9d7229e561e53c352a85b1fc"},
{file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:936bf3842db5b2048eaa53dade907b1160f318e7c90c74bfab86f1e47720bdd6"},
{file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:620beacc3373277700d0e27718aa8b25f7b383eb8001fba94ee00aeea1459d89"},
{file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2735ecbbf37e52db4ea970e539fd2d450d213517b77745114f92867f3fc246eb"},
{file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:473c83c5e2359bb81b0b6fde870b41b2764fcdd36d997485e07e72cc3a62264a"},
{file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:968fa1fb3c27398b28a4eca1cbd1e19355c4d3a6007f7398d48826bbe3a0f728"},
{file = "tokenizers-0.15.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:865c60ae6eaebdde7da66191ee9b7db52e542ed8ee9d2c653b6d190a9351b980"},
{file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7c0d8b52664ab2d4a8d6686eb5effc68b78608a9008f086a122a7b2996befbab"},
{file = "tokenizers-0.15.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f33dfbdec3784093a9aebb3680d1f91336c56d86cc70ddf88708251da1fe9064"},
{file = "tokenizers-0.15.2-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:d44ba80988ff9424e33e0a49445072ac7029d8c0e1601ad25a0ca5f41ed0c1d6"},
{file = "tokenizers-0.15.2-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:dce74266919b892f82b1b86025a613956ea0ea62a4843d4c4237be2c5498ed3a"},
{file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0ef06b9707baeb98b316577acb04f4852239d856b93e9ec3a299622f6084e4be"},
{file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c73e2e74bbb07910da0d37c326869f34113137b23eadad3fc00856e6b3d9930c"},
{file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4eeb12daf02a59e29f578a865f55d87cd103ce62bd8a3a5874f8fdeaa82e336b"},
{file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ba9f6895af58487ca4f54e8a664a322f16c26bbb442effd01087eba391a719e"},
{file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccec77aa7150e38eec6878a493bf8c263ff1fa8a62404e16c6203c64c1f16a26"},
{file = "tokenizers-0.15.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3f40604f5042ff210ba82743dda2b6aa3e55aa12df4e9f2378ee01a17e2855e"},
{file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5645938a42d78c4885086767c70923abad047163d809c16da75d6b290cb30bbe"},
{file = "tokenizers-0.15.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:05a77cbfebe28a61ab5c3891f9939cc24798b63fa236d84e5f29f3a85a200c00"},
{file = "tokenizers-0.15.2-cp37-none-win32.whl", hash = "sha256:361abdc068e8afe9c5b818769a48624687fb6aaed49636ee39bec4e95e1a215b"},
{file = "tokenizers-0.15.2-cp37-none-win_amd64.whl", hash = "sha256:7ef789f83eb0f9baeb4d09a86cd639c0a5518528f9992f38b28e819df397eb06"},
{file = "tokenizers-0.15.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4fe1f74a902bee74a3b25aff180fbfbf4f8b444ab37c4d496af7afd13a784ed2"},
{file = "tokenizers-0.15.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c4b89038a684f40a6b15d6b09f49650ac64d951ad0f2a3ea9169687bbf2a8ba"},
{file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:d05a1b06f986d41aed5f2de464c003004b2df8aaf66f2b7628254bcbfb72a438"},
{file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508711a108684111ec8af89d3a9e9e08755247eda27d0ba5e3c50e9da1600f6d"},
{file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:daa348f02d15160cb35439098ac96e3a53bacf35885072611cd9e5be7d333daa"},
{file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:494fdbe5932d3416de2a85fc2470b797e6f3226c12845cadf054dd906afd0442"},
{file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c2d60f5246f4da9373f75ff18d64c69cbf60c3bca597290cea01059c336d2470"},
{file = "tokenizers-0.15.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93268e788825f52de4c7bdcb6ebc1fcd4a5442c02e730faa9b6b08f23ead0e24"},
{file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6fc7083ab404019fc9acafe78662c192673c1e696bd598d16dc005bd663a5cf9"},
{file = "tokenizers-0.15.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e39b41e5531d6b2122a77532dbea60e171ef87a3820b5a3888daa847df4153"},
{file = "tokenizers-0.15.2-cp38-none-win32.whl", hash = "sha256:06cd0487b1cbfabefb2cc52fbd6b1f8d4c37799bd6c6e1641281adaa6b2504a7"},
{file = "tokenizers-0.15.2-cp38-none-win_amd64.whl", hash = "sha256:5179c271aa5de9c71712e31cb5a79e436ecd0d7532a408fa42a8dbfa4bc23fd9"},
{file = "tokenizers-0.15.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:82f8652a74cc107052328b87ea8b34291c0f55b96d8fb261b3880216a9f9e48e"},
{file = "tokenizers-0.15.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:02458bee6f5f3139f1ebbb6d042b283af712c0981f5bc50edf771d6b762d5e4f"},
{file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c9a09cd26cca2e1c349f91aa665309ddb48d71636370749414fbf67bc83c5343"},
{file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:158be8ea8554e5ed69acc1ce3fbb23a06060bd4bbb09029431ad6b9a466a7121"},
{file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ddba9a2b0c8c81633eca0bb2e1aa5b3a15362b1277f1ae64176d0f6eba78ab1"},
{file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ef5dd1d39797044642dbe53eb2bc56435308432e9c7907728da74c69ee2adca"},
{file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:454c203164e07a860dbeb3b1f4a733be52b0edbb4dd2e5bd75023ffa8b49403a"},
{file = "tokenizers-0.15.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cf6b7f1d4dc59af960e6ffdc4faffe6460bbfa8dce27a58bf75755ffdb2526d"},
{file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2ef09bbc16519f6c25d0c7fc0c6a33a6f62923e263c9d7cca4e58b8c61572afb"},
{file = "tokenizers-0.15.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c9a2ebdd2ad4ec7a68e7615086e633857c85e2f18025bd05d2a4399e6c5f7169"},
{file = "tokenizers-0.15.2-cp39-none-win32.whl", hash = "sha256:918fbb0eab96fe08e72a8c2b5461e9cce95585d82a58688e7f01c2bd546c79d0"},
{file = "tokenizers-0.15.2-cp39-none-win_amd64.whl", hash = "sha256:524e60da0135e106b254bd71f0659be9f89d83f006ea9093ce4d1fab498c6d0d"},
{file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:6a9b648a58281c4672212fab04e60648fde574877d0139cd4b4f93fe28ca8944"},
{file = "tokenizers-0.15.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7c7d18b733be6bbca8a55084027f7be428c947ddf871c500ee603e375013ffba"},
{file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:13ca3611de8d9ddfbc4dc39ef54ab1d2d4aaa114ac8727dfdc6a6ec4be017378"},
{file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:237d1bf3361cf2e6463e6c140628e6406766e8b27274f5fcc62c747ae3c6f094"},
{file = "tokenizers-0.15.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67a0fe1e49e60c664915e9fb6b0cb19bac082ab1f309188230e4b2920230edb3"},
{file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4e022fe65e99230b8fd89ebdfea138c24421f91c1a4f4781a8f5016fd5cdfb4d"},
{file = "tokenizers-0.15.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:d857be2df69763362ac699f8b251a8cd3fac9d21893de129bc788f8baaef2693"},
{file = "tokenizers-0.15.2-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:708bb3e4283177236309e698da5fcd0879ce8fd37457d7c266d16b550bcbbd18"},
{file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c35e09e9899b72a76e762f9854e8750213f67567787d45f37ce06daf57ca78"},
{file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1257f4394be0d3b00de8c9e840ca5601d0a4a8438361ce9c2b05c7d25f6057b"},
{file = "tokenizers-0.15.2-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02272fe48280e0293a04245ca5d919b2c94a48b408b55e858feae9618138aeda"},
{file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dc3ad9ebc76eabe8b1d7c04d38be884b8f9d60c0cdc09b0aa4e3bcf746de0388"},
{file = "tokenizers-0.15.2-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:32e16bdeffa7c4f46bf2152172ca511808b952701d13e7c18833c0b73cb5c23f"},
{file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fb16ba563d59003028b678d2361a27f7e4ae0ab29c7a80690efa20d829c81fdb"},
{file = "tokenizers-0.15.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:2277c36d2d6cdb7876c274547921a42425b6810d38354327dd65a8009acf870c"},
{file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1cf75d32e8d250781940d07f7eece253f2fe9ecdb1dc7ba6e3833fa17b82fcbc"},
{file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b3b31884dc8e9b21508bb76da80ebf7308fdb947a17affce815665d5c4d028"},
{file = "tokenizers-0.15.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10122d8d8e30afb43bb1fe21a3619f62c3e2574bff2699cf8af8b0b6c5dc4a3"},
{file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d88b96ff0fe8e91f6ef01ba50b0d71db5017fa4e3b1d99681cec89a85faf7bf7"},
{file = "tokenizers-0.15.2-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:37aaec5a52e959892870a7c47cef80c53797c0db9149d458460f4f31e2fb250e"},
{file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e2ea752f2b0fe96eb6e2f3adbbf4d72aaa1272079b0dfa1145507bd6a5d537e6"},
{file = "tokenizers-0.15.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:4b19a808d8799fda23504a5cd31d2f58e6f52f140380082b352f877017d6342b"},
{file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:64c86e5e068ac8b19204419ed8ca90f9d25db20578f5881e337d203b314f4104"},
{file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de19c4dc503c612847edf833c82e9f73cd79926a384af9d801dcf93f110cea4e"},
{file = "tokenizers-0.15.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ea09acd2fe3324174063d61ad620dec3bcf042b495515f27f638270a7d466e8b"},
{file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:cf27fd43472e07b57cf420eee1e814549203d56de00b5af8659cb99885472f1f"},
{file = "tokenizers-0.15.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:7ca22bd897537a0080521445d91a58886c8c04084a6a19e6c78c586e0cfa92a5"},
{file = "tokenizers-0.15.2.tar.gz", hash = "sha256:e6e9c6e019dd5484be5beafc775ae6c925f4c69a3487040ed09b45e13df2cb91"},
]
[package.dependencies]
huggingface_hub = ">=0.16.4,<1.0"
[package.extras]
dev = ["tokenizers[testing]"]
docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"]
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
[[package]]
name = "tomli"
version = "2.0.1"
@ -923,6 +1143,26 @@ files = [
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
[[package]]
name = "tqdm"
version = "4.66.2"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
{file = "tqdm-4.66.2-py3-none-any.whl", hash = "sha256:1ee4f8a893eb9bef51c6e35730cebf234d5d0b6bd112b0271e10ed7c24a02bd9"},
{file = "tqdm-4.66.2.tar.gz", hash = "sha256:6cd52cdf0fef0e0f543299cfc96fec90d7b8a7e88745f411ec33eb44d5ed3531"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "typing-extensions"
version = "4.10.0"
@ -1010,4 +1250,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "446d53423f89a3a378db9cff7ce8cb392e146e294d31e9d1bbfc23108a571097"
content-hash = "892d39e8c2a50b049e3c2bb8f39b8a2126273d9644944456bd9900f6d15c5cb2"

@ -9,7 +9,7 @@ readme = "README.md"
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.28"
langchain-text-splitters = "^0.0.1"
ai21 = "^2.1.2"
ai21 = "^2.2.1"
[tool.poetry.group.test]
optional = true

@ -1,39 +1,80 @@
"""Test ChatAI21 chat model."""
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.outputs import ChatGeneration
from langchain_ai21.chat_models import ChatAI21
_MODEL_NAME = "j2-ultra"
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
def test_invoke() -> None:
@pytest.mark.parametrize(
ids=[
"when_j2_model",
"when_jamba_model",
],
argnames=["model"],
argvalues=[
(J2_CHAT_MODEL_NAME,),
(JAMBA_CHAT_MODEL_NAME,),
],
)
def test_invoke(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=_MODEL_NAME)
llm = ChatAI21(model=model)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
def test_generation() -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=_MODEL_NAME)
message = HumanMessage(content="Hello")
@pytest.mark.parametrize(
ids=[
"when_j2_model_num_results_is_1",
"when_j2_model_num_results_is_3",
"when_jamba_model_n_is_1",
"when_jamba_model_n_is_3",
],
argnames=["model", "num_results"],
argvalues=[
(J2_CHAT_MODEL_NAME, 1),
(J2_CHAT_MODEL_NAME, 3),
(JAMBA_CHAT_MODEL_NAME, 1),
(JAMBA_CHAT_MODEL_NAME, 3),
],
)
def test_generation(model: str, num_results: int) -> None:
"""Test generation with multiple models and different result counts."""
# Determine the configuration key based on the model type
config_key = "n" if model == JAMBA_CHAT_MODEL_NAME else "num_results"
result = llm.generate([[message], [message]], config=dict(tags=["foo"]))
# Create the model instance using the appropriate key for the result count
llm = ChatAI21(model=model, **{config_key: num_results})
message = HumanMessage(content="Hello, this is a test. Can you help me please?")
result = llm.generate([[message]], config=dict(tags=["foo"]))
for generations in result.generations:
assert len(generations) == 1
assert len(generations) == num_results
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
async def test_ageneration() -> None:
@pytest.mark.parametrize(
ids=[
"when_j2_model",
"when_jamba_model",
],
argnames=["model"],
argvalues=[
(J2_CHAT_MODEL_NAME,),
(JAMBA_CHAT_MODEL_NAME,),
],
)
async def test_ageneration(model: str) -> None:
"""Test invoke tokens from AI21."""
llm = ChatAI21(model=_MODEL_NAME)
llm = ChatAI21(model=model)
message = HumanMessage(content="Hello")
result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))

@ -1,5 +1,6 @@
"""Standard LangChain interface tests"""
import time
from typing import Type
import pytest
@ -9,17 +10,53 @@ from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
from langchain_ai21 import ChatAI21
class TestAI21Standard(ChatModelIntegrationTests):
class TestAI21J2(ChatModelIntegrationTests):
def teardown(self) -> None:
# avoid getting rate limited
time.sleep(1)
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
def test_stream(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_stream(
chat_model_class,
chat_model_params,
)
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
async def test_astream(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
await super().test_astream(
chat_model_class,
chat_model_params,
)
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "j2-ultra",
}
class TestAI21Jamba(ChatModelIntegrationTests):
def teardown(self) -> None:
# avoid getting rate limited
time.sleep(1)
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@pytest.mark.xfail(reason="Emits AIMessage instead of AIMessageChunk.")
def test_stream(
self,
@ -41,3 +78,9 @@ class TestAI21Standard(ChatModelIntegrationTests):
chat_model_class,
chat_model_params,
)
@pytest.fixture
def chat_model_params(self) -> dict:
return {
"model": "jamba-instruct-preview",
}

@ -0,0 +1,9 @@
import pytest
from langchain_ai21.chat.chat_adapter import ChatAdapter
from langchain_ai21.chat.chat_factory import create_chat_adapter
@pytest.fixture
def chat_adapter(model: str) -> ChatAdapter:
    """Build a ``ChatAdapter`` for the parametrized ``model`` fixture value."""
    adapter = create_chat_adapter(model)
    return adapter

@ -0,0 +1,195 @@
from typing import List
import pytest
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models import RoleType
from ai21.models.chat import ChatMessage
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.messages import (
ChatMessage as LangChainChatMessage,
)
from langchain_ai21.chat.chat_adapter import ChatAdapter
# Model identifiers used to parametrize the adapter tests in this module.
_J2_MODEL_NAME = "j2-ultra"
_JAMBA_MODEL_NAME = "jamba-instruct-preview"
@pytest.mark.parametrize(
    ids=[
        "when_human_message_j2_model",
        "when_ai_message_j2_model",
        "when_human_message_jamba_model",
        "when_ai_message_jamba_model",
    ],
    argnames=["model", "message", "expected_ai21_message"],
    argvalues=[
        (
            _J2_MODEL_NAME,
            HumanMessage(content="Human Message Content"),
            J2ChatMessage(role=RoleType.USER, text="Human Message Content"),
        ),
        (
            _J2_MODEL_NAME,
            AIMessage(content="AI Message Content"),
            J2ChatMessage(role=RoleType.ASSISTANT, text="AI Message Content"),
        ),
        (
            _JAMBA_MODEL_NAME,
            HumanMessage(content="Human Message Content"),
            ChatMessage(role=RoleType.USER, content="Human Message Content"),
        ),
        (
            _JAMBA_MODEL_NAME,
            AIMessage(content="AI Message Content"),
            ChatMessage(role=RoleType.ASSISTANT, content="AI Message Content"),
        ),
    ],
)
def test_convert_message_to_ai21_message(
    message: BaseMessage,
    expected_ai21_message: ChatMessage,
    chat_adapter: ChatAdapter,
) -> None:
    """Each LangChain message maps onto the adapter-specific AI21 message type."""
    converted = chat_adapter._convert_message_to_ai21_message(message)

    assert converted == expected_ai21_message
@pytest.mark.parametrize(
    ids=[
        "when_system_message_j2_model",
        "when_langchain_chat_message_j2_model",
    ],
    argnames=["model", "message"],
    argvalues=[
        (
            _J2_MODEL_NAME,
            SystemMessage(content="System Message Content"),
        ),
        (
            _J2_MODEL_NAME,
            LangChainChatMessage(content="Chat Message Content", role="human"),
        ),
    ],
)
def test_convert_message_to_ai21_message__when_invalid_role__should_raise_exception(
    message: BaseMessage,
    chat_adapter: ChatAdapter,
) -> None:
    """Roles other than human/AI must be rejected with a descriptive error."""
    with pytest.raises(ValueError) as exc_info:
        chat_adapter._convert_message_to_ai21_message(message)

    # The error text doubles as user-facing guidance, so pin it exactly.
    assert exc_info.value.args[0] == (
        f"Could not resolve role type from message {message}. "
        f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
    )
@pytest.mark.parametrize(
    ids=[
        "when_all_messages_are_human_messages__should_return_system_none_j2_model",
        "when_first_message_is_system__should_return_system_j2_model",
        "when_all_messages_are_human_messages__should_return_system_none_jamba_model",
        "when_first_message_is_system__should_return_system_jamba_model",
    ],
    argnames=["model", "messages", "expected_messages"],
    argvalues=[
        (
            _J2_MODEL_NAME,
            [
                HumanMessage(content="Human Message Content 1"),
                HumanMessage(content="Human Message Content 2"),
            ],
            {
                "system": "",
                "messages": [
                    J2ChatMessage(
                        role=RoleType.USER,
                        text="Human Message Content 1",
                    ),
                    J2ChatMessage(
                        role=RoleType.USER,
                        text="Human Message Content 2",
                    ),
                ],
            },
        ),
        (
            _J2_MODEL_NAME,
            [
                SystemMessage(content="System Message Content 1"),
                HumanMessage(content="Human Message Content 1"),
            ],
            {
                "system": "System Message Content 1",
                "messages": [
                    J2ChatMessage(
                        role=RoleType.USER,
                        text="Human Message Content 1",
                    ),
                ],
            },
        ),
        (
            _JAMBA_MODEL_NAME,
            [
                HumanMessage(content="Human Message Content 1"),
                HumanMessage(content="Human Message Content 2"),
            ],
            {
                "messages": [
                    ChatMessage(
                        role=RoleType.USER,
                        content="Human Message Content 1",
                    ),
                    ChatMessage(
                        role=RoleType.USER,
                        content="Human Message Content 2",
                    ),
                ]
            },
        ),
        (
            _JAMBA_MODEL_NAME,
            [
                SystemMessage(content="System Message Content 1"),
                HumanMessage(content="Human Message Content 1"),
            ],
            {
                "messages": [
                    ChatMessage(role="system", content="System Message Content 1"),
                    ChatMessage(role="user", content="Human Message Content 1"),
                ],
            },
        ),
    ],
)
def test_convert_messages(
    chat_adapter: ChatAdapter,
    messages: List[BaseMessage],
    expected_messages: List[ChatMessage],
) -> None:
    """``convert_messages`` must emit the payload shape each adapter expects.

    J2 splits a leading system message into a separate ``system`` key, while
    Jamba keeps system messages inline in the ``messages`` list.
    """
    assert chat_adapter.convert_messages(messages) == expected_messages
@pytest.mark.parametrize(
    ids=[
        "when_j2_model",
    ],
    argnames=["model"],
    argvalues=[
        (_J2_MODEL_NAME,),
    ],
)
def test_convert_messages__when_system_is_not_first(chat_adapter: ChatAdapter) -> None:
    """A system message anywhere but position zero must be rejected."""
    out_of_order_messages = [
        HumanMessage(content="Human Message Content 1"),
        SystemMessage(content="System Message Content 1"),
    ]

    with pytest.raises(ValueError):
        chat_adapter.convert_messages(out_of_order_messages)

@ -0,0 +1,34 @@
from typing import Type
import pytest
from langchain_ai21.chat.chat_adapter import (
ChatAdapter,
J2ChatAdapter,
JambaChatCompletionsAdapter,
)
from langchain_ai21.chat.chat_factory import create_chat_adapter
from tests.unit_tests.conftest import J2_CHAT_MODEL_NAME, JAMBA_CHAT_MODEL_NAME
@pytest.mark.parametrize(
    ids=[
        "when_j2_model",
        "when_jamba_model",
    ],
    argnames=["model", "expected_chat_type"],
    argvalues=[
        (J2_CHAT_MODEL_NAME, J2ChatAdapter),
        (JAMBA_CHAT_MODEL_NAME, JambaChatCompletionsAdapter),
    ],
)
def test_create_chat_adapter_with_supported_models(
    model: str, expected_chat_type: Type[ChatAdapter]
) -> None:
    """The factory returns the adapter class that matches each known model."""
    assert isinstance(create_chat_adapter(model), expected_chat_type)
def test_create_chat_adapter__when_model_not_supported() -> None:
    """Unknown model names must be rejected by the factory with ValueError."""
    unknown_model = "unsupported-model"

    with pytest.raises(ValueError):
        create_chat_adapter(unknown_model)

@ -21,6 +21,8 @@ from ai21.models import (
from ai21.models.responses.segmentation_response import Segment
from pytest_mock import MockerFixture
J2_CHAT_MODEL_NAME = "j2-ultra"
JAMBA_CHAT_MODEL_NAME = "jamba-instruct-preview"
DUMMY_API_KEY = "test_api_key"
BASIC_EXAMPLE_LLM_PARAMETERS = {
@ -39,6 +41,23 @@ BASIC_EXAMPLE_LLM_PARAMETERS = {
),
}
BASIC_EXAMPLE_CHAT_PARAMETERS = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
),
"n": 3,
}
SEGMENTS = [
Segment(
segment_type="normal_text",
@ -82,6 +101,23 @@ BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT = {
).to_dict(),
}
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT = {
"num_results": 3,
"max_tokens": 20,
"min_tokens": 10,
"temperature": 0.5,
"top_p": 0.5,
"top_k_return": 0,
"frequency_penalty": Penalty(scale=0.2, apply_to_numbers=True).to_dict(),
"presence_penalty": Penalty(scale=0.2, apply_to_stopwords=True).to_dict(),
"count_penalty": Penalty(
scale=0.2,
apply_to_punctuation=True,
apply_to_emojis=True,
).to_dict(),
"n": 3,
}
@pytest.fixture
def mocked_completion_response(mocker: MockerFixture) -> Mock:

@ -1,5 +1,5 @@
"""Test chat model integration."""
from typing import List, Optional, cast
from typing import cast
from unittest.mock import Mock, call
import pytest
@ -7,24 +7,18 @@ from ai21 import MissingApiKeyError
from ai21.models import ChatMessage, Penalty, RoleType
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.messages import (
ChatMessage as LangChainChatMessage,
)
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_ai21.chat_models import (
ChatAI21,
_convert_message_to_ai21_message,
_convert_messages_to_ai21_messages,
)
from tests.unit_tests.conftest import (
BASIC_EXAMPLE_LLM_PARAMETERS,
BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
BASIC_EXAMPLE_CHAT_PARAMETERS,
BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
DUMMY_API_KEY,
temporarily_unset_api_key,
)
@ -43,7 +37,7 @@ def test_initialization__when_default_parameters_in_init() -> None:
def test_initialization__when_custom_parameters_in_init() -> None:
model = "j2-mid"
model = "j2-ultra"
num_results = 1
max_tokens = 10
min_tokens = 20
@ -79,101 +73,6 @@ def test_initialization__when_custom_parameters_in_init() -> None:
assert count_penalty == count_penalty
@pytest.mark.parametrize(
ids=[
"when_human_message",
"when_ai_message",
],
argnames=["message", "expected_ai21_message"],
argvalues=[
(
HumanMessage(content="Human Message Content"),
ChatMessage(role=RoleType.USER, text="Human Message Content"),
),
(
AIMessage(content="AI Message Content"),
ChatMessage(role=RoleType.ASSISTANT, text="AI Message Content"),
),
],
)
def test_convert_message_to_ai21_message(
message: BaseMessage, expected_ai21_message: ChatMessage
) -> None:
ai21_message = _convert_message_to_ai21_message(message)
assert ai21_message == expected_ai21_message
@pytest.mark.parametrize(
    argnames=["message"],
    argvalues=[
        pytest.param(
            SystemMessage(content="System Message Content"),
            id="when_system_message",
        ),
        pytest.param(
            LangChainChatMessage(content="Chat Message Content", role="human"),
            id="when_langchain_chat_message",
        ),
    ],
)
def test_convert_message_to_ai21_message__when_invalid_role__should_raise_exception(
    message: BaseMessage,
) -> None:
    """Messages with unsupported roles are rejected with a descriptive ValueError."""
    expected_error = (
        f"Could not resolve role type from message {message}. "
        f"Only support {HumanMessage.__name__} and {AIMessage.__name__}."
    )
    with pytest.raises(ValueError) as exc_info:
        _convert_message_to_ai21_message(message)
    assert exc_info.value.args[0] == expected_error
@pytest.mark.parametrize(
    argnames=["messages", "expected_system", "expected_messages"],
    argvalues=[
        pytest.param(
            [
                HumanMessage(content="Human Message Content 1"),
                HumanMessage(content="Human Message Content 2"),
            ],
            None,
            [
                ChatMessage(role=RoleType.USER, text="Human Message Content 1"),
                ChatMessage(role=RoleType.USER, text="Human Message Content 2"),
            ],
            id="when_all_messages_are_human_messages__should_return_system_none",
        ),
        pytest.param(
            [
                SystemMessage(content="System Message Content 1"),
                HumanMessage(content="Human Message Content 1"),
            ],
            "System Message Content 1",
            [
                ChatMessage(role=RoleType.USER, text="Human Message Content 1"),
            ],
            id="when_first_message_is_system__should_return_system",
        ),
    ],
)
def test_convert_messages(
    messages: List[BaseMessage],
    expected_system: Optional[str],
    expected_messages: List[ChatMessage],
) -> None:
    """A leading system message is split out; the rest become AI21 user messages."""
    system, ai21_messages = _convert_messages_to_ai21_messages(messages)
    # Compare both outputs together so a failure reports the full conversion result.
    assert (system, ai21_messages) == (expected_system, expected_messages)
def test_convert_messages_when_system_is_not_first__should_raise_value_error() -> None:
    """A system message anywhere but position zero must be rejected."""
    out_of_order_messages = [
        HumanMessage(content="Human Message Content 1"),
        SystemMessage(content="System Message Content 1"),
    ]
    with pytest.raises(ValueError):
        _convert_messages_to_ai21_messages(out_of_order_messages)
def test_invoke(mock_client_with_chat: Mock) -> None:
chat_input = "I'm Pickle Rick"
@@ -181,7 +80,7 @@ def test_invoke(mock_client_with_chat: Mock) -> None:
model="j2-ultra",
api_key=DUMMY_API_KEY,
client=mock_client_with_chat,
**BASIC_EXAMPLE_LLM_PARAMETERS,
**BASIC_EXAMPLE_CHAT_PARAMETERS,
)
llm.invoke(input=chat_input, config=dict(tags=["foo"]), stop=["\n"])
@@ -190,7 +89,7 @@ def test_invoke(mock_client_with_chat: Mock) -> None:
messages=[ChatMessage(role=RoleType.USER, text=chat_input)],
system="",
stop_sequences=["\n"],
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
)
@@ -207,7 +106,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
llm = ChatAI21(
model="j2-ultra",
client=mock_client_with_chat,
**BASIC_EXAMPLE_LLM_PARAMETERS,
**BASIC_EXAMPLE_CHAT_PARAMETERS,
)
llm.generate(messages=[messages0, messages1])
@@ -226,7 +125,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
ChatMessage(role=RoleType.USER, text=str(messages0[2].content)),
],
system="",
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
),
call(
model="j2-ultra",
@@ -234,7 +133,7 @@ def test_generate(mock_client_with_chat: Mock) -> None:
ChatMessage(role=RoleType.USER, text=str(messages1[1].content)),
],
system="system message",
**BASIC_EXAMPLE_LLM_PARAMETERS_AS_DICT,
**BASIC_EXAMPLE_CHAT_PARAMETERS_AS_DICT,
),
]
)

@@ -9,7 +9,7 @@ from langchain_standard_tests.unit_tests import ChatModelUnitTests
from langchain_ai21 import ChatAI21
class TestAI21Standard(ChatModelUnitTests):
class TestAI21J2(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatAI21
@@ -20,3 +20,16 @@ class TestAI21Standard(ChatModelUnitTests):
"model": "j2-ultra",
"api_key": "test_api_key",
}
class TestAI21Jamba(ChatModelUnitTests):
    """Runs the standard chat-model unit-test suite against the Jamba model."""

    @pytest.fixture
    def chat_model_class(self) -> Type[BaseChatModel]:
        # Same ChatAI21 class as the J2 suite; the model is selected via params.
        return ChatAI21

    @pytest.fixture
    def chat_model_params(self) -> dict:
        # Dummy credentials — the standard unit tests never hit the real API.
        return {
            "api_key": "test_api_key",
            "model": "jamba-instruct",
        }

Loading…
Cancel
Save