Mirror of https://github.com/hwchase17/langchain
Synced 2024-11-08 07:10:35 +00:00

Commit 4d7f6fa968

Description: Added support for batching when using the AI21 embeddings model.
Twitter handle: https://github.com/AI21Labs

Co-authored-by: Asaf Gardin <asafg@ai21.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
80 lines
2.2 KiB
Python
from itertools import islice
from typing import Any, Iterator, List, Optional

from ai21.models import EmbedType
from langchain_core.embeddings import Embeddings

from langchain_ai21.ai21_base import AI21Base

# Maximum number of texts sent to the AI21 embed endpoint in one request.
_DEFAULT_BATCH_SIZE = 128


def _split_texts_into_batches(
    texts: List[str], batch_size: int
) -> Iterator[List[str]]:
    # Lazily yield successive chunks of up to `batch_size` texts. The
    # two-argument form of iter() calls the lambda repeatedly until it
    # returns the sentinel (an empty list), i.e. until texts_itr is
    # exhausted.
    texts_itr = iter(texts)
    return iter(lambda: list(islice(texts_itr, batch_size)), [])


class AI21Embeddings(Embeddings, AI21Base):
    """AI21 embeddings model.

    To use, you should have the ``AI21_API_KEY`` environment variable set,
    or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_ai21 import AI21Embeddings

            embeddings = AI21Embeddings()
            query_result = embeddings.embed_query("Hello embeddings world!")
    """

    batch_size: int = _DEFAULT_BATCH_SIZE
    """Maximum number of texts to embed in each batch."""

    def embed_documents(
        self,
        texts: List[str],
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[List[float]]:
        """Embed search docs."""
        return self._send_embeddings(
            texts=texts,
            batch_size=batch_size or self.batch_size,
            embed_type=EmbedType.SEGMENT,
            **kwargs,
        )

    def embed_query(
        self,
        text: str,
        *,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> List[float]:
        """Embed query text."""
        return self._send_embeddings(
            texts=[text],
            batch_size=batch_size or self.batch_size,
            embed_type=EmbedType.QUERY,
            **kwargs,
        )[0]

    def _send_embeddings(
        self, texts: List[str], *, batch_size: int, embed_type: EmbedType, **kwargs: Any
    ) -> List[List[float]]:
        # Issue one API call per batch, then flatten the per-batch results
        # back into a single list that preserves the input order.
        chunks = _split_texts_into_batches(texts, batch_size)
        responses = [
            self.client.embed.create(
                texts=chunk,
                type=embed_type,
                **kwargs,
            )
            for chunk in chunks
        ]

        return [
            result.embedding for response in responses for result in response.results
        ]
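For reference, a minimal sketch of how the batching added in this commit behaves from the caller's side. It assumes the `ai21` and `langchain-ai21` packages are installed and `AI21_API_KEY` is set in the environment; the document counts and batch sizes are illustrative only.

from langchain_ai21 import AI21Embeddings

# batch_size is a field on the class, so it can be set per instance
# (the default is _DEFAULT_BATCH_SIZE, i.e. 128).
embeddings = AI21Embeddings(batch_size=64)

# 300 texts are split into ceil(300 / 64) = 5 chunks, so the AI21 embed
# endpoint is called 5 times; the flattened result preserves input order.
docs = [f"document {i}" for i in range(300)]  # illustrative corpus
vectors = embeddings.embed_documents(docs)
assert len(vectors) == 300

# A per-call batch_size overrides the instance setting for that call.
vectors = embeddings.embed_documents(docs, batch_size=128)

# Queries are sent with EmbedType.QUERY and return a single vector.
query_vector = embeddings.embed_query("Hello embeddings world!")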