mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
4d7f6fa968
Description: Added support for batching when using AI21 Embeddings model Twitter handle: https://github.com/AI21Labs --------- Co-authored-by: Asaf Gardin <asafg@ai21.com> Co-authored-by: Erick Friis <erick@langchain.dev>
105 lines
2.5 KiB
Python
105 lines
2.5 KiB
Python
"""Test AI21LLM llm."""
|
|
|
|
from langchain_ai21.llms import AI21LLM
|
|
|
|
# AI21 model name used by every AI21LLM constructed in this test module.
_MODEL_NAME = "j2-mid"
|
|
|
|
|
|
def _generate_llm() -> AI21LLM:
    """Build an AI21LLM configured with non-default parameters.

    Used by the ``generate``/``agenerate`` tests. The overrides are
    ``max_tokens=2``, ``temperature=0`` and ``epoch=1``.
    """
    return AI21LLM(
        model=_MODEL_NAME,
        max_tokens=2,  # Use fewer tokens for a faster response
        temperature=0,  # For a consistent response
        epoch=1,
    )
|
|
|
|
|
|
def test_stream() -> None:
    """Stream a short prompt through AI21LLM and check every chunk is text."""
    streaming_llm = AI21LLM(model=_MODEL_NAME)

    chunks = streaming_llm.stream("I'm Pickle Rick")
    for chunk in chunks:
        assert isinstance(chunk, str)
|
|
|
|
|
|
async def test_abatch() -> None:
    """Test async batch completions from AI21LLM."""
    llm = AI21LLM(
        model=_MODEL_NAME,
    )

    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])

    # Batching must return exactly one completion per input prompt.
    assert len(result) == 2
    for completion in result:
        assert isinstance(completion, str)
|
|
|
|
|
|
async def test_abatch_tags() -> None:
    """Test async batch completions from AI21LLM with tags in the run config."""
    llm = AI21LLM(
        model=_MODEL_NAME,
    )

    result = await llm.abatch(
        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
    )

    # Batching must return exactly one completion per input prompt.
    assert len(result) == 2
    for completion in result:
        assert isinstance(completion, str)
|
|
|
|
|
|
def test_batch() -> None:
    """Test synchronous batch completions from AI21LLM."""
    llm = AI21LLM(
        model=_MODEL_NAME,
    )

    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])

    # Batching must return exactly one completion per input prompt.
    assert len(result) == 2
    for completion in result:
        assert isinstance(completion, str)
|
|
|
|
|
|
async def test_ainvoke() -> None:
    """Asynchronously invoke AI21LLM with a tagged config and expect text back."""
    async_llm = AI21LLM(model=_MODEL_NAME)

    completion = await async_llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})

    assert isinstance(completion, str)
|
|
|
|
|
|
def test_invoke() -> None:
    """Test a single synchronous completion from AI21LLM with tags in the config."""
    llm = AI21LLM(
        model=_MODEL_NAME,
    )

    # Dict literal (not ``dict(...)``) matches test_ainvoke and lets type
    # checkers infer the config type from the call site.
    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})

    assert isinstance(result, str)
|
|
|
|
|
|
def test__generate() -> None:
    """Test ``generate`` with non-default parameters and a stop sequence.

    Checks that generations are produced and that a non-zero
    ``token_count`` is reported in ``llm_output``.
    """
    llm = _generate_llm()
    llm_result = llm.generate(
        prompts=["Hey there, my name is Pickle Rick. What is your name?"],
        stop=["##"],
    )

    assert len(llm_result.generations) > 0
    # llm_output is optional on the result; assert its presence explicitly so a
    # missing value fails with a clear message instead of a TypeError.
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["token_count"] != 0
|
|
|
|
|
|
async def test__agenerate() -> None:
    """Test ``agenerate`` with non-default parameters and a stop sequence.

    Checks that generations are produced and that a non-zero
    ``token_count`` is reported in ``llm_output``.
    """
    llm = _generate_llm()
    llm_result = await llm.agenerate(
        prompts=["Hey there, my name is Pickle Rick. What is your name?"],
        stop=["##"],
    )

    assert len(llm_result.generations) > 0
    # llm_output is optional on the result; assert its presence explicitly so a
    # missing value fails with a clear message instead of a TypeError.
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["token_count"] != 0
|