ai21[patch]: AI21 Labs Batch Support in Embeddings (#18633)

Description: Added support for batching when using the AI21 embeddings model
Twitter handle: https://github.com/AI21Labs

---------

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Commit 4d7f6fa968 (parent 321db89e87)
Author: Asaf Joseph Gardin
Date: 2024-03-15 01:10:23 +02:00 (committed by GitHub)
7 changed files with 147 additions and 23 deletions
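
In practice, the new keyword works like this (a minimal usage sketch, not part of the diff; assumes an AI21 API key is configured in the environment):

    from langchain_ai21.embeddings import AI21Embeddings

    embeddings = AI21Embeddings()
    # batch_size caps how many texts are sent to the embed endpoint per
    # request; it falls back to the batch_size field (default 128) if omitted.
    vectors = embeddings.embed_documents(
        ["doc one", "doc two", "doc three"],
        batch_size=2,  # two requests: ["doc one", "doc two"] and ["doc three"]
    )
    query_vector = embeddings.embed_query("Hello embeddings world!")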

File: langchain_ai21/embeddings.py

@@ -1,10 +1,18 @@
-from typing import Any, List
+from itertools import islice
+from typing import Any, Iterator, List, Optional
 
 from ai21.models import EmbedType
 from langchain_core.embeddings import Embeddings
 
 from langchain_ai21.ai21_base import AI21Base
 
+_DEFAULT_BATCH_SIZE = 128
+
+
+def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
+    texts_itr = iter(texts)
+    return iter(lambda: list(islice(texts_itr, batch_size)), [])
+
 
 class AI21Embeddings(Embeddings, AI21Base):
     """AI21 Embeddings embedding model.
@@ -20,22 +28,52 @@ class AI21Embeddings(Embeddings, AI21Base):
         query_result = embeddings.embed_query("Hello embeddings world!")
     """
 
+    batch_size: int = _DEFAULT_BATCH_SIZE
+    """Maximum number of texts to embed in each batch"""
+
-    def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
+    def embed_documents(
+        self,
+        texts: List[str],
+        *,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[List[float]]:
         """Embed search docs."""
-        response = self.client.embed.create(
+        return self._send_embeddings(
             texts=texts,
-            type=EmbedType.SEGMENT,
+            batch_size=batch_size or self.batch_size,
+            embed_type=EmbedType.SEGMENT,
             **kwargs,
         )
 
-        return [result.embedding for result in response.results]
-
-    def embed_query(self, text: str, **kwargs: Any) -> List[float]:
+    def embed_query(
+        self,
+        text: str,
+        *,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[float]:
         """Embed query text."""
-        response = self.client.embed.create(
+        return self._send_embeddings(
             texts=[text],
-            type=EmbedType.QUERY,
+            batch_size=batch_size or self.batch_size,
+            embed_type=EmbedType.QUERY,
             **kwargs,
-        )
-
-        return [result.embedding for result in response.results][0]
+        )[0]
+
+    def _send_embeddings(
+        self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any
+    ) -> List[List[float]]:
+        chunks = _split_texts_into_batches(texts, batch_size)
+        responses = [
+            self.client.embed.create(
+                texts=chunk,
+                type=embed_type,
+                **kwargs,
+            )
+            for chunk in chunks
+        ]
+
+        return [
+            result.embedding for response in responses for result in response.results
+        ]
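
The batching helper leans on the two-argument form of iter(callable, sentinel): the lambda is called repeatedly, slicing up to batch_size items off the shared iterator, until it returns a value equal to the sentinel [] (i.e. the source is exhausted). A self-contained sketch of the same idiom (the batches name is illustrative):

    from itertools import islice
    from typing import Iterator, List

    def batches(items: List[str], size: int) -> Iterator[List[str]]:
        it = iter(items)
        # iter(callable, sentinel) calls the lambda until it returns [],
        # which happens exactly when islice finds the iterator empty.
        return iter(lambda: list(islice(it, size)), [])

    print(list(batches(["a", "b", "c", "d", "e"], 2)))
    # -> [['a', 'b'], ['c', 'd'], ['e']]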

File: pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ai21"
-version = "0.1.0"
+version = "0.1.1"
 description = "An integration package connecting AI21 and LangChain"
 authors = []
 readme = "README.md"

File: tests/integration_tests/test_chat_models.py

@@ -1,13 +1,16 @@
 """Test ChatAI21 chat model."""
 from langchain_core.messages import HumanMessage
 from langchain_core.outputs import ChatGeneration
 
 from langchain_ai21.chat_models import ChatAI21
 
+_MODEL_NAME = "j2-ultra"
+
+
 def test_invoke() -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model="j2-ultra")
+    llm = ChatAI21(model=_MODEL_NAME)
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
@@ -15,7 +18,7 @@ def test_invoke() -> None:
 
 def test_generation() -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model="j2-ultra")
+    llm = ChatAI21(model=_MODEL_NAME)
     message = HumanMessage(content="Hello")
 
     result = llm.generate([[message], [message]], config=dict(tags=["foo"]))
@@ -30,7 +33,7 @@ def test_generation() -> None:
 
 async def test_ageneration() -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model="j2-ultra")
+    llm = ChatAI21(model=_MODEL_NAME)
     message = HumanMessage(content="Hello")
 
     result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))

File: tests/integration_tests/test_embeddings.py

@@ -1,4 +1,5 @@
 """Test AI21 embeddings."""
+
 from langchain_ai21.embeddings import AI21Embeddings
 
 
@@ -17,3 +18,20 @@ def test_langchain_ai21_embedding_query() -> None:
     embedding = AI21Embeddings()
     output = embedding.embed_query(document)
     assert len(output) > 0
+
+
+def test_langchain_ai21_embedding_documents__with_explicit_chunk_size() -> None:
+    """Test AI21 embeddings with chunk size passed as an argument."""
+    documents = ["foo", "bar"]
+    embedding = AI21Embeddings()
+    output = embedding.embed_documents(documents, batch_size=1)
+    assert len(output) == 2
+    assert len(output[0]) > 0
+
+
+def test_langchain_ai21_embedding_query__with_explicit_chunk_size() -> None:
+    """Test AI21 embeddings with chunk size passed as an argument."""
+    documents = "foo bar"
+    embedding = AI21Embeddings()
+    output = embedding.embed_query(documents, batch_size=1)
+    assert len(output) > 0

File: tests/integration_tests/test_llms.py

@@ -1,15 +1,16 @@
 """Test AI21LLM llm."""
 from langchain_ai21.llms import AI21LLM
 
+_MODEL_NAME = "j2-mid"
+
 
 def _generate_llm() -> AI21LLM:
     """
     Testing AI21LLm using non default parameters with the following parameters
     """
     return AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
         max_tokens=2,  # Use less tokens for a faster response
         temperature=0,  # for a consistent response
         epoch=1,
@@ -19,7 +20,7 @@ def _generate_llm() -> AI21LLM:
 
 def test_stream() -> None:
     """Test streaming tokens from AI21."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     for token in llm.stream("I'm Pickle Rick"):
@@ -29,7 +30,7 @@ def test_stream() -> None:
 
 async def test_abatch() -> None:
     """Test streaming tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -40,7 +41,7 @@ async def test_abatch() -> None:
 
 async def test_abatch_tags() -> None:
     """Test batch tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = await llm.abatch(
@@ -53,7 +54,7 @@ async def test_abatch_tags() -> None:
 
 def test_batch() -> None:
     """Test batch tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -64,7 +65,7 @@ def test_batch() -> None:
 
 async def test_ainvoke() -> None:
     """Test invoke tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
@@ -74,7 +75,7 @@ async def test_ainvoke() -> None:
 
 def test_invoke() -> None:
     """Test invoke tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))

File: tests/unit_tests/test_embeddings.py

@@ -1,4 +1,6 @@
 """Test embedding model integration."""
+from typing import List
+
 from unittest.mock import Mock
 
 import pytest
@@ -65,3 +67,36 @@ def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
         texts=texts,
         type=EmbedType.SEGMENT,
     )
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "empty_texts",
+        "chunk_size_greater_than_texts_length",
+        "chunk_size_equal_to_texts_length",
+        "chunk_size_less_than_texts_length",
+        "chunk_size_one_with_multiple_texts",
+        "chunk_size_greater_than_texts_length",
+    ],
+    argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
+    argvalues=[
+        ([], 3, 0),
+        (["text1", "text2", "text3"], 5, 1),
+        (["text1", "text2", "text3"], 3, 1),
+        (["text1", "text2", "text3", "text4", "text5"], 2, 3),
+        (["text1", "text2", "text3"], 1, 3),
+        (["text1", "text2", "text3"], 10, 1),
+    ],
+)
+def test_get_len_safe_embeddings(
+    mock_client_with_embeddings: Mock,
+    texts: List[str],
+    chunk_size: int,
+    expected_internal_embeddings_calls: int,
+) -> None:
+    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
+    llm.embed_documents(texts=texts, batch_size=chunk_size)
+    assert (
+        mock_client_with_embeddings.embed.create.call_count
+        == expected_internal_embeddings_calls
+    )
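
The expected call counts here follow ceil(len(texts) / batch_size), with zero calls for empty input: 5 texts at batch size 2 yield 3 requests, while 3 texts fit in a single request at any batch size of 3 or more. A quick check of that arithmetic (the helper name is illustrative):

    import math

    def expected_calls(n_texts: int, batch_size: int) -> int:
        # One embed.create call per non-empty chunk.
        return math.ceil(n_texts / batch_size)

    assert expected_calls(0, 3) == 0   # empty input: no requests
    assert expected_calls(3, 5) == 1   # everything fits in one chunk
    assert expected_calls(5, 2) == 3   # 2 + 2 + 1
    assert expected_calls(3, 1) == 3   # one text per request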

File: tests/unit_tests (new file)

@@ -0,0 +1,29 @@
+from typing import List
+
+import pytest
+
+from langchain_ai21.embeddings import _split_texts_into_batches
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "when_chunk_size_is_2__should_return_3_chunks",
+        "when_texts_is_empty__should_return_empty_list",
+        "when_chunk_size_is_1__should_return_10_chunks",
+    ],
+    argnames=["input_texts", "chunk_size", "expected_output"],
+    argvalues=[
+        (["a", "b", "c", "d", "e"], 2, [["a", "b"], ["c", "d"], ["e"]]),
+        ([], 3, []),
+        (
+            ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
+            1,
+            [["1"], ["2"], ["3"], ["4"], ["5"], ["6"], ["7"], ["8"], ["9"], ["10"]],
+        ),
+    ],
+)
+def test_chunked_text_generator(
+    input_texts: List[str], chunk_size: int, expected_output: List[List[str]]
+) -> None:
+    result = list(_split_texts_into_batches(input_texts, chunk_size))
+    assert result == expected_output