mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
4d7f6fa968
Description: Added support for batching when using AI21 Embeddings model Twitter handle: https://github.com/AI21Labs --------- Co-authored-by: Asaf Gardin <asafg@ai21.com> Co-authored-by: Erick Friis <erick@langchain.dev>
47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
"""Test ChatAI21 chat model."""
|
|
|
|
from langchain_core.messages import HumanMessage
|
|
from langchain_core.outputs import ChatGeneration
|
|
|
|
from langchain_ai21.chat_models import ChatAI21
|
|
|
|
# AI21 model identifier shared by every integration test in this module.
_MODEL_NAME = "j2-ultra"
|
|
def test_invoke() -> None:
    """Invoke the chat model with a plain string prompt.

    Verifies only that the live AI21 endpoint answers and that the
    response content is a string; the actual text is model-dependent.
    """
    chat = ChatAI21(model=_MODEL_NAME)

    response = chat.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
    assert isinstance(response.content, str)
|
def test_generation() -> None:
    """Generate completions for a batch of two identical prompts.

    Each prompt should yield exactly one ``ChatGeneration`` whose ``text``
    mirrors the underlying message content.
    """
    chat = ChatAI21(model=_MODEL_NAME)
    prompt = HumanMessage(content="Hello")

    response = chat.generate([[prompt], [prompt]], config=dict(tags=["foo"]))

    for batch in response.generations:
        # One candidate per prompt is expected from the AI21 backend.
        assert len(batch) == 1
        for candidate in batch:
            assert isinstance(candidate, ChatGeneration)
            assert isinstance(candidate.text, str)
            assert candidate.text == candidate.message.content
|
async def test_ageneration() -> None:
    """Async counterpart of ``test_generation`` using ``agenerate``.

    Sends the same two-prompt batch and checks the identical invariants:
    one ``ChatGeneration`` per prompt with string text matching its message.
    """
    chat = ChatAI21(model=_MODEL_NAME)
    prompt = HumanMessage(content="Hello")

    response = await chat.agenerate([[prompt], [prompt]], config=dict(tags=["foo"]))

    for batch in response.generations:
        # One candidate per prompt is expected from the AI21 backend.
        assert len(batch) == 1
        for candidate in batch:
            assert isinstance(candidate, ChatGeneration)
            assert isinstance(candidate.text, str)
            assert candidate.text == candidate.message.content
|