diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 3b8fa5559f..189df58d2a 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -1,12 +1,57 @@ """Wrapper around OpenAI embedding models.""" -from typing import Any, Dict, List, Optional +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Optional import numpy as np from pydantic import BaseModel, Extra, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from langchain.embeddings.base import Embeddings from langchain.utils import get_from_dict_or_env +logger = logging.getLogger(__name__) + + +def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any]: + import openai + + min_seconds = 4 + max_seconds = 10 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(embeddings.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(embeddings) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return embeddings.client.create(**kwargs) + + return _completion_with_retry(**kwargs) + class OpenAIEmbeddings(BaseModel, Embeddings): """Wrapper around OpenAI embedding models. @@ -27,6 +72,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings): query_model_name: str = "text-embedding-ada-002" embedding_ctx_length: int = -1 openai_api_key: Optional[str] = None + chunk_size: int = 1000 + """Maximum number of texts to embed in each batch""" + max_retries: int = 6 + """Maximum number of retries to make when generating.""" class Config: """Configuration for this pydantic object.""" @@ -74,7 +123,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb def _get_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: int = 1000 + self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None ) -> List[List[float]]: embeddings: List[List[float]] = [[] for i in range(len(texts))] try: @@ -92,9 +141,12 @@ class OpenAIEmbeddings(BaseModel, Embeddings): indices += [i] batched_embeddings = [] - for i in range(0, len(tokens), chunk_size): - response = self.client.create( - input=tokens[i : i + chunk_size], engine=self.document_model_name + _chunk_size = chunk_size or self.chunk_size + for i in range(0, len(tokens), _chunk_size): + response = embed_with_retry( + self, + input=tokens[i : i + _chunk_size], + engine=self.document_model_name, ) batched_embeddings += [r["embedding"] for r in response["data"]] @@ -124,33 +176,34 @@ class OpenAIEmbeddings(BaseModel, Embeddings): return self._get_len_safe_embeddings([text], engine=engine)[0] else: text = text.replace("\n", " ") - return self.client.create(input=[text], engine=engine)["data"][0][ + return embed_with_retry(self, input=[text], engine=engine)["data"][0][ "embedding" ] def embed_documents( - self, texts: List[str], chunk_size: int = 1000 + self, texts: List[str], chunk_size: Optional[int] = 0 ) -> List[List[float]]: """Call out to OpenAI's embedding endpoint for embedding search docs. Args: texts: The list of texts to embed. - chunk_size: The maximum number of texts to send to OpenAI at once - (max 1000). + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. Returns: List of embeddings, one for each text. """ # handle large batches of texts if self.embedding_ctx_length > 0: - return self._get_len_safe_embeddings( - texts, engine=self.document_model_name, chunk_size=chunk_size - ) + return self._get_len_safe_embeddings(texts, engine=self.document_model_name) else: results = [] - for i in range(0, len(texts), chunk_size): - response = self.client.create( - input=texts[i : i + chunk_size], engine=self.document_model_name + _chunk_size = chunk_size or self.chunk_size + for i in range(0, len(texts), _chunk_size): + response = embed_with_retry( + self, + input=texts[i : i + _chunk_size], + engine=self.document_model_name, ) results += [r["embedding"] for r in response["data"]] return results diff --git a/tests/integration_tests/embeddings/test_openai.py b/tests/integration_tests/embeddings/test_openai.py index 3fdb1e532e..9aa7d19c78 100644 --- a/tests/integration_tests/embeddings/test_openai.py +++ b/tests/integration_tests/embeddings/test_openai.py @@ -14,9 +14,9 @@ def test_openai_embedding_documents() -> None: def test_openai_embedding_documents_multiple() -> None: """Test openai embeddings.""" documents = ["foo bar", "bar foo", "foo"] - embedding = OpenAIEmbeddings() + embedding = OpenAIEmbeddings(chunk_size=2) embedding.embedding_ctx_length = 8191 - output = embedding.embed_documents(documents, chunk_size=2) + output = embedding.embed_documents(documents) assert len(output) == 3 assert len(output[0]) == 1536 assert len(output[1]) == 1536