from __future__ import annotations

import json
import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator(embeddings: VoyageEmbeddings) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used for embedding requests.

    Args:
        embeddings: The embeddings object whose ``max_retries`` bounds the
            number of attempts.

    Returns:
        A decorator that retries the wrapped callable, logging a warning
        before each sleep and re-raising the last exception once the attempts
        are exhausted.
    """
    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    return retry(
        reraise=True,
        stop=stop_after_attempt(embeddings.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def _check_response(response: dict) -> dict:
    """Return ``response`` unchanged if it carries a ``"data"`` entry.

    Args:
        response: The decoded JSON body returned by the Voyage API.

    Returns:
        The same ``response`` object.

    Raises:
        RuntimeError: If ``"data"`` is missing; the message embeds the
            JSON-serialized response.
    """
    if "data" not in response:
        raise RuntimeError(f"Voyage API Error. Message: {json.dumps(response)}")
    return response


def embed_with_retry(embeddings: VoyageEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call.

    Args:
        embeddings: The embeddings object supplying the retry configuration.
        **kwargs: Keyword arguments forwarded verbatim to ``requests.post``.

    Returns:
        The decoded JSON response, after checking that it contains ``"data"``.
    """
    retry_decorator = _create_retry_decorator(embeddings)

    @retry_decorator
    def _embed_with_retry(**kwargs: Any) -> Any:
        response = requests.post(**kwargs)
        # A body without "data" raises inside the retried function, so it goes
        # through the same retry policy as a failed request.
        return _check_response(response.json())

    return _embed_with_retry(**kwargs)


class VoyageEmbeddings(BaseModel, Embeddings):
    """Voyage embedding models.

    To use, you should have the environment variable ``VOYAGE_API_KEY`` set with
    your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import VoyageEmbeddings

            voyage = VoyageEmbeddings(voyage_api_key="your-api-key")
            text = "This is a test query."
            query_result = voyage.embed_query(text)
    """

    model: str = "voyage-01"
    voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
    voyage_api_key: Optional[SecretStr] = None
    batch_size: int = 8
    """Maximum number of texts to embed in each API request."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout in seconds for the API request."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding. Must have tqdm installed if
    set to True."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and store it as a secret string.

        The key is looked up under ``voyage_api_key`` in ``values`` or in the
        ``VOYAGE_API_KEY`` environment variable.

        Args:
            values: The raw field values passed to the constructor.

        Returns:
            ``values`` with ``voyage_api_key`` replaced by its secret-wrapped
            form.
        """
        values["voyage_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
        )
        return values

    def _invocation_params(
        self, input: List[str], input_type: Optional[str] = None
    ) -> Dict:
        """Assemble the keyword arguments for one ``requests.post`` call.

        Args:
            input: The batch of texts to embed.
            input_type: Value sent as ``input_type`` in the JSON payload.

        Returns:
            A dict with the ``url``, ``headers``, ``json`` and ``timeout``
            entries for the request.
        """
        api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
        params = {
            "url": self.voyage_api_base,
            "headers": {"Authorization": f"Bearer {api_key}"},
            "json": {"model": self.model, "input": input, "input_type": input_type},
            "timeout": self.request_timeout,
        }
        return params

    def _get_embeddings(
        self,
        texts: List[str],
        batch_size: Optional[int] = None,
        input_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed ``texts`` in batches and return the embeddings in order.

        Args:
            texts: The texts to embed.
            batch_size: Number of texts per request. Defaults to
                ``self.batch_size`` when None.
            input_type: None, ``"query"`` or ``"document"``.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ValueError: If ``input_type`` is not an accepted option, or if the
                effective batch size is not positive.
            ImportError: If ``show_progress_bar`` is set and tqdm is missing.
        """
        # Validate arguments first so bad input fails before any optional
        # import or request is attempted.
        if input_type and input_type not in ["query", "document"]:
            raise ValueError(
                f"input_type {input_type} is invalid. Options: None, 'query', "
                "'document'."
            )

        if batch_size is None:
            batch_size = self.batch_size
        # A negative step would produce an empty range and silently return no
        # embeddings; a zero step fails with an unrelated range() error.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}."
            )

        embeddings: List[List[float]] = []

        if self.show_progress_bar:
            try:
                from tqdm.auto import tqdm
            except ImportError as e:
                raise ImportError(
                    "Must have tqdm installed if `show_progress_bar` is set to True. "
                    "Please install with `pip install tqdm`."
                ) from e

            _iter = tqdm(range(0, len(texts), batch_size))
        else:
            _iter = range(0, len(texts), batch_size)

        for i in _iter:
            response = embed_with_retry(
                self,
                **self._invocation_params(
                    input=texts[i : i + batch_size], input_type=input_type
                ),
            )
            embeddings.extend(r["embedding"] for r in response["data"])

        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Voyage Embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return self._get_embeddings(
            texts, batch_size=self.batch_size, input_type="document"
        )

    def embed_query(self, text: str) -> List[float]:
        """Call out to Voyage Embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self._get_embeddings([text], input_type="query")[0]

    def embed_general_texts(
        self, texts: List[str], *, input_type: Optional[str] = None
    ) -> List[List[float]]:
        """Call out to Voyage Embedding endpoint for embedding general text.

        Args:
            texts: The list of texts to embed.
            input_type: Type of the input text. Default to None, meaning the type
                is unspecified. Other options: query, document.

        Returns:
            List of embeddings, one for each text.
        """
        return self._get_embeddings(
            texts, batch_size=self.batch_size, input_type=input_type
        )