langchain/libs/partners/mistralai/langchain_mistralai/embeddings.py
2024-04-23 20:56:42 +00:00

213 lines
6.7 KiB
Python

import asyncio
import logging
import warnings
from typing import Dict, Iterable, List, Optional
import httpx
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tokenizers import Tokenizer # type: ignore
# Upper bound on the summed token count of one batch built by
# MistralAIEmbeddings._get_batches.
MAX_TOKENS = 16000
# Module-level logger used to report failed MistralAI requests.
logger = logging.getLogger(__name__)
class DummyTokenizer:
    """Fallback tokenizer used when the real one cannot be fetched (e.g. from Huggingface).

    Every text is split into its individual characters, so a caller that takes
    ``len()`` of each result ends up counting one token per character.
    """

    def encode_batch(self, texts: List[str]) -> List[List[str]]:
        """Return, for each input text, the list of its characters."""
        return list(map(list, texts))
class MistralAIEmbeddings(BaseModel, Embeddings):
    """MistralAI embedding models.

    To use, set the environment variable `MISTRAL_API_KEY` to your API key or
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_mistralai import MistralAIEmbeddings
            mistral = MistralAIEmbeddings(
                model="mistral-embed",
                api_key="my-api-key"
            )
    """

    client: httpx.Client = Field(default=None)  #: :meta private:
    async_client: httpx.AsyncClient = Field(default=None)  #: :meta private:
    mistral_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    # Base URL that request paths such as "/embeddings" are resolved against.
    endpoint: str = "https://api.mistral.ai/v1/"
    # NOTE(review): not used yet; retries are still a todo in validate_environment.
    max_retries: int = 5
    # Passed as the `timeout` of both HTTP clients.
    timeout: int = 120
    # Maximum number of batch requests aembed_documents keeps in flight at once.
    max_concurrent_requests: int = 64
    # Only used to count tokens when splitting input into batches.
    tokenizer: Tokenizer = Field(default=None)
    model: str = "mistral-embed"

    class Config:
        # Reject unknown fields, allow the non-pydantic client/tokenizer types,
        # and accept both `mistral_api_key` and its alias `api_key`.
        extra = Extra.forbid
        arbitrary_types_allowed = True
        allow_population_by_field_name = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and build the HTTP clients and tokenizer.

        The key comes from the `mistral_api_key` / `api_key` field or, failing
        that, the `MISTRAL_API_KEY` environment variable (empty string if
        neither is set). If no tokenizer was supplied, the Mixtral tokenizer is
        downloaded from Huggingface, falling back to `DummyTokenizer` (with a
        warning) when that raises `IOError`.
        """
        values["mistral_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values, "mistral_api_key", "MISTRAL_API_KEY", default=""
            )
        )
        api_key_str = values["mistral_api_key"].get_secret_value()
        # The sync and async clients send exactly the same headers.
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key_str}",
        }
        # todo: handle retries
        values["client"] = httpx.Client(
            base_url=values["endpoint"],
            headers=headers,
            timeout=values["timeout"],
        )
        # todo: handle retries (concurrency is bounded in aembed_documents)
        values["async_client"] = httpx.AsyncClient(
            base_url=values["endpoint"],
            headers=headers,
            timeout=values["timeout"],
        )
        # .get(): the key can be absent if the field itself failed validation.
        if values.get("tokenizer") is None:
            try:
                values["tokenizer"] = Tokenizer.from_pretrained(
                    "mistralai/Mixtral-8x7B-v0.1"
                )
            except IOError:  # huggingface_hub GatedRepoError
                warnings.warn(
                    "Could not download mistral tokenizer from Huggingface for "
                    "calculating batch sizes. Set a Huggingface token via the "
                    "HF_TOKEN environment variable to download the real tokenizer. "
                    "Falling back to a dummy tokenizer that uses `len()`."
                )
                values["tokenizer"] = DummyTokenizer()
        return values

    def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
        """Split `texts` into batches whose token totals do not exceed MAX_TOKENS.

        Tokens are counted with `self.tokenizer`. A single text that is longer
        than MAX_TOKENS cannot be split, so it is yielded as its own batch.
        Input order is preserved and an empty batch is never yielded.
        """
        batch: List[str] = []
        batch_tokens = 0
        text_token_lengths = [
            len(encoded) for encoded in self.tokenizer.encode_batch(texts)
        ]
        for text, text_tokens in zip(texts, text_token_lengths):
            if batch_tokens + text_tokens > MAX_TOKENS:
                # An oversized first text must not produce an empty batch.
                if batch:
                    yield batch
                batch = [text]
                batch_tokens = text_tokens
            else:
                batch.append(text)
                batch_tokens += text_tokens
        if batch:
            yield batch

    @staticmethod
    def _parse_embeddings(response: httpx.Response) -> List[List[float]]:
        """Raise on an HTTP error status, otherwise return the response's vectors."""
        # Without this check, an error body would most likely fail the
        # ["data"] lookup below with an opaque KeyError that hides the
        # status code and server message.
        response.raise_for_status()
        return [
            list(map(float, embedding_obj["embedding"]))
            for embedding_obj in response.json()["data"]
        ]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts.

        The texts are split into token-bounded batches, which are posted to
        the `/embeddings` endpoint one after another.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        Raises:
            httpx.HTTPStatusError: If the API responds with a 4xx/5xx status.
        """
        try:
            # Lazy generator: each request is sent only when its batch is parsed.
            batch_responses = (
                self.client.post(
                    url="/embeddings",
                    json=dict(model=self.model, input=batch),
                )
                for batch in self._get_batches(texts)
            )
            return [
                embedding
                for response in batch_responses
                for embedding in self._parse_embeddings(response)
            ]
        except Exception as e:
            logger.error("An error occurred with MistralAI: %s", e)
            raise

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts asynchronously.

        Batches are sent concurrently. At most `max_concurrent_requests`
        requests are in flight at once (treated as 1 if set lower).

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        Raises:
            httpx.HTTPStatusError: If the API responds with a 4xx/5xx status.
        """
        try:
            # Clamp to >= 1: Semaphore(0) would block forever and a negative
            # value raises ValueError.
            semaphore = asyncio.Semaphore(max(1, self.max_concurrent_requests))

            async def _post_batch(batch: List[str]) -> httpx.Response:
                async with semaphore:
                    return await self.async_client.post(
                        url="/embeddings",
                        json=dict(model=self.model, input=batch),
                    )

            # gather() returns results in the order the coroutines were passed.
            batch_responses = await asyncio.gather(
                *[_post_batch(batch) for batch in self._get_batches(texts)]
            )
            return [
                embedding
                for response in batch_responses
                for embedding in self._parse_embeddings(response)
            ]
        except Exception as e:
            logger.error("An error occurred with MistralAI: %s", e)
            raise

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Embed a single query text asynchronously.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return (await self.aembed_documents([text]))[0]