Mirror of https://github.com/hwchase17/langchain
ai21[patch]: AI21 Labs Batch Support in Embeddings (#18633)
Description: Added support for batching when using the AI21 embeddings model.
Twitter handle: https://github.com/AI21Labs

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent 321db89e87
commit 4d7f6fa968
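To make the change concrete, here is a minimal usage sketch (not part of the commit itself); it assumes langchain-ai21 0.1.1 is installed and an AI21 API key is available in the environment, and the texts and batch_size value are illustrative:

    # Hedged usage sketch of the new batching parameter.
    from langchain_ai21.embeddings import AI21Embeddings

    embeddings = AI21Embeddings()  # reads the AI21 API key from the environment

    # embed_documents now splits the input into batches of at most
    # `batch_size` texts (default 128) and issues one API call per batch.
    vectors = embeddings.embed_documents(
        ["first document", "second document", "third document"],
        batch_size=2,  # 3 texts / batch size 2 -> 2 API calls
    )
    assert len(vectors) == 3  # one embedding per input text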
libs/partners/ai21/langchain_ai21/embeddings.py

@@ -1,10 +1,18 @@
-from typing import Any, List
+from itertools import islice
+from typing import Any, Iterator, List, Optional
 
 from ai21.models import EmbedType
 from langchain_core.embeddings import Embeddings
 
 from langchain_ai21.ai21_base import AI21Base
 
+_DEFAULT_BATCH_SIZE = 128
+
+
+def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
+    texts_itr = iter(texts)
+    return iter(lambda: list(islice(texts_itr, batch_size)), [])
+
 
 class AI21Embeddings(Embeddings, AI21Base):
     """AI21 Embeddings embedding model.
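The _split_texts_into_batches helper relies on the two-argument form of the built-in iter(callable, sentinel): the callable is invoked on each iteration step, and iteration stops when it returns the sentinel. Here the callable slices off up to batch_size items, and the empty list signals exhaustion. A standalone sketch of the same idiom (the helper is copied from the diff for illustration):

    from itertools import islice
    from typing import Iterator, List


    def _split_texts_into_batches(texts: List[str], batch_size: int) -> Iterator[List[str]]:
        texts_itr = iter(texts)
        # iter(callable, sentinel): call the lambda repeatedly until it
        # returns [] (the sentinel), i.e. until islice finds no more items
        # to take from texts_itr.
        return iter(lambda: list(islice(texts_itr, batch_size)), [])


    print(list(_split_texts_into_batches(["a", "b", "c", "d", "e"], 2)))
    # -> [['a', 'b'], ['c', 'd'], ['e']]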
@@ -20,22 +28,52 @@ class AI21Embeddings(Embeddings, AI21Base):
         query_result = embeddings.embed_query("Hello embeddings world!")
     """
 
-    def embed_documents(self, texts: List[str], **kwargs: Any) -> List[List[float]]:
+    batch_size: int = _DEFAULT_BATCH_SIZE
+    """Maximum number of texts to embed in each batch"""
+
+    def embed_documents(
+        self,
+        texts: List[str],
+        *,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[List[float]]:
         """Embed search docs."""
-        response = self.client.embed.create(
+        return self._send_embeddings(
             texts=texts,
-            type=EmbedType.SEGMENT,
+            batch_size=batch_size or self.batch_size,
+            embed_type=EmbedType.SEGMENT,
             **kwargs,
         )
 
-        return [result.embedding for result in response.results]
-
-    def embed_query(self, text: str, **kwargs: Any) -> List[float]:
+    def embed_query(
+        self,
+        text: str,
+        *,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ) -> List[float]:
         """Embed query text."""
-        response = self.client.embed.create(
+        return self._send_embeddings(
             texts=[text],
-            type=EmbedType.QUERY,
+            batch_size=batch_size or self.batch_size,
+            embed_type=EmbedType.QUERY,
             **kwargs,
-        )
-
-        return [result.embedding for result in response.results][0]
+        )[0]
+
+    def _send_embeddings(
+        self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any
+    ) -> List[List[float]]:
+        chunks = _split_texts_into_batches(texts, batch_size)
+        responses = [
+            self.client.embed.create(
+                texts=chunk,
+                type=embed_type,
+                **kwargs,
+            )
+            for chunk in chunks
+        ]
+
+        return [
+            result.embedding for response in responses for result in response.results
+        ]
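Two details of _send_embeddings are worth noting. First, the `batch_size or self.batch_size` fallback means an explicit batch_size=0 also falls through to the default, which is harmless since a zero batch size is meaningless. Second, the final comprehension flattens one embedding list out of several batched responses, preserving input order. A toy illustration with stand-in objects (the Result/Response dataclasses below only mimic the shape of the ai21 SDK's response objects, they are not the SDK's types):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Result:  # stand-in for one embedding result
        embedding: List[float]


    @dataclass
    class Response:  # stand-in for one batched embed.create response
        results: List[Result]


    responses = [
        Response(results=[Result([0.1]), Result([0.2])]),  # batch of 2 texts
        Response(results=[Result([0.3])]),                 # final partial batch
    ]

    # Same nested comprehension as in _send_embeddings: iterate responses in
    # order, then results within each response.
    flat = [result.embedding for response in responses for result in response.results]
    assert flat == [[0.1], [0.2], [0.3]]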
libs/partners/ai21/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-ai21"
-version = "0.1.0"
+version = "0.1.1"
 description = "An integration package connecting AI21 and LangChain"
 authors = []
 readme = "README.md"
libs/partners/ai21/tests/integration_tests/test_chat_models.py

@@ -1,13 +1,16 @@
 """Test ChatAI21 chat model."""
 
 from langchain_core.messages import HumanMessage
 from langchain_core.outputs import ChatGeneration
 
 from langchain_ai21.chat_models import ChatAI21
 
+_MODEL_NAME = "j2-ultra"
+
+
 def test_invoke() -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model="j2-ultra")
+    llm = ChatAI21(model=_MODEL_NAME)
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
@@ -15,7 +18,7 @@ def test_invoke() -> None:
 
 def test_generation() -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model="j2-ultra")
+    llm = ChatAI21(model=_MODEL_NAME)
     message = HumanMessage(content="Hello")
 
     result = llm.generate([[message], [message]], config=dict(tags=["foo"]))
@@ -30,7 +33,7 @@ def test_generation() -> None:
 
 async def test_ageneration() -> None:
     """Test invoke tokens from AI21."""
-    llm = ChatAI21(model="j2-ultra")
+    llm = ChatAI21(model=_MODEL_NAME)
     message = HumanMessage(content="Hello")
 
     result = await llm.agenerate([[message], [message]], config=dict(tags=["foo"]))
libs/partners/ai21/tests/integration_tests/test_embeddings.py

@@ -1,4 +1,5 @@
 """Test AI21 embeddings."""
+
 from langchain_ai21.embeddings import AI21Embeddings
 
 
@@ -17,3 +18,20 @@ def test_langchain_ai21_embedding_query() -> None:
     embedding = AI21Embeddings()
     output = embedding.embed_query(document)
     assert len(output) > 0
+
+
+def test_langchain_ai21_embedding_documents__with_explicit_chunk_size() -> None:
+    """Test AI21 embeddings with chunk size passed as an argument."""
+    documents = ["foo", "bar"]
+    embedding = AI21Embeddings()
+    output = embedding.embed_documents(documents, batch_size=1)
+    assert len(output) == 2
+    assert len(output[0]) > 0
+
+
+def test_langchain_ai21_embedding_query__with_explicit_chunk_size() -> None:
+    """Test AI21 embeddings with chunk size passed as an argument."""
+    documents = "foo bar"
+    embedding = AI21Embeddings()
+    output = embedding.embed_query(documents, batch_size=1)
+    assert len(output) > 0
libs/partners/ai21/tests/integration_tests/test_llms.py

@@ -1,15 +1,16 @@
 """Test AI21LLM llm."""
 
 from langchain_ai21.llms import AI21LLM
 
+_MODEL_NAME = "j2-mid"
+
+
 def _generate_llm() -> AI21LLM:
     """
     Testing AI21LLm using non default parameters with the following parameters
     """
     return AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
         max_tokens=2,  # Use less tokens for a faster response
         temperature=0,  # for a consistent response
         epoch=1,
@@ -19,7 +20,7 @@ def _generate_llm() -> AI21LLM:
 def test_stream() -> None:
     """Test streaming tokens from AI21."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     for token in llm.stream("I'm Pickle Rick"):
@@ -29,7 +30,7 @@ def test_stream() -> None:
 async def test_abatch() -> None:
     """Test streaming tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -40,7 +41,7 @@ async def test_abatch() -> None:
 async def test_abatch_tags() -> None:
     """Test batch tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = await llm.abatch(
@@ -53,7 +54,7 @@ async def test_abatch_tags() -> None:
 def test_batch() -> None:
     """Test batch tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -64,7 +65,7 @@ def test_batch() -> None:
 async def test_ainvoke() -> None:
     """Test invoke tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
@@ -74,7 +75,7 @@ async def test_ainvoke() -> None:
 def test_invoke() -> None:
     """Test invoke tokens from AI21LLM."""
     llm = AI21LLM(
-        model="j2-ultra",
+        model=_MODEL_NAME,
     )
 
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
libs/partners/ai21/tests/unit_tests/test_embeddings.py

@@ -1,4 +1,6 @@
 """Test embedding model integration."""
 
+from typing import List
 from unittest.mock import Mock
 
 import pytest
@@ -65,3 +67,36 @@ def test_embed_documents(mock_client_with_embeddings: Mock) -> None:
         texts=texts,
         type=EmbedType.SEGMENT,
     )
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "empty_texts",
+        "chunk_size_greater_than_texts_length",
+        "chunk_size_equal_to_texts_length",
+        "chunk_size_less_than_texts_length",
+        "chunk_size_one_with_multiple_texts",
+        "chunk_size_greater_than_texts_length",
+    ],
+    argnames=["texts", "chunk_size", "expected_internal_embeddings_calls"],
+    argvalues=[
+        ([], 3, 0),
+        (["text1", "text2", "text3"], 5, 1),
+        (["text1", "text2", "text3"], 3, 1),
+        (["text1", "text2", "text3", "text4", "text5"], 2, 3),
+        (["text1", "text2", "text3"], 1, 3),
+        (["text1", "text2", "text3"], 10, 1),
+    ],
+)
+def test_get_len_safe_embeddings(
+    mock_client_with_embeddings: Mock,
+    texts: List[str],
+    chunk_size: int,
+    expected_internal_embeddings_calls: int,
+) -> None:
+    llm = AI21Embeddings(client=mock_client_with_embeddings, api_key=DUMMY_API_KEY)
+    llm.embed_documents(texts=texts, batch_size=chunk_size)
+    assert (
+        mock_client_with_embeddings.embed.create.call_count
+        == expected_internal_embeddings_calls
+    )
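The expected_internal_embeddings_calls column follows directly from the batching arithmetic: the number of embed.create calls is ceil(len(texts) / chunk_size), and an empty input makes no calls at all. A quick sanity check of every parametrized row (illustrative, not from the commit):

    import math

    rows = [(0, 3, 0), (3, 5, 1), (3, 3, 1), (5, 2, 3), (3, 1, 3), (3, 10, 1)]
    for n_texts, chunk_size, expected_calls in rows:
        # ceil(0 / k) == 0, so the empty-input row needs no special case
        assert math.ceil(n_texts / chunk_size) == expected_calls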
libs/partners/ai21/tests/unit_tests/test_utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from typing import List
+
+import pytest
+
+from langchain_ai21.embeddings import _split_texts_into_batches
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "when_chunk_size_is_2__should_return_3_chunks",
+        "when_texts_is_empty__should_return_empty_list",
+        "when_chunk_size_is_1__should_return_10_chunks",
+    ],
+    argnames=["input_texts", "chunk_size", "expected_output"],
+    argvalues=[
+        (["a", "b", "c", "d", "e"], 2, [["a", "b"], ["c", "d"], ["e"]]),
+        ([], 3, []),
+        (
+            ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
+            1,
+            [["1"], ["2"], ["3"], ["4"], ["5"], ["6"], ["7"], ["8"], ["9"], ["10"]],
+        ),
+    ],
+)
+def test_chunked_text_generator(
+    input_texts: List[str], chunk_size: int, expected_output: List[List[str]]
+) -> None:
+    result = list(_split_texts_into_batches(input_texts, chunk_size))
+    assert result == expected_output