diff --git a/langchain/embeddings/base.py b/langchain/embeddings/base.py
index 4a56cd6a..6dd700c4 100644
--- a/langchain/embeddings/base.py
+++ b/langchain/embeddings/base.py
@@ -13,3 +13,11 @@ class Embeddings(ABC):
     @abstractmethod
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Asynchronously embed search docs."""
+        raise NotImplementedError
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Asynchronously embed query text."""
+        raise NotImplementedError
diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index 56813d2a..9c233230 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -18,6 +18,7 @@ from typing import (
 import numpy as np
 from pydantic import BaseModel, Extra, root_validator
 from tenacity import (
+    AsyncRetrying,
     before_sleep_log,
     retry,
     retry_if_exception_type,
@@ -53,6 +54,42 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
     )
 
 
+def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
+    """Build a decorator that retries an async OpenAI call with tenacity."""
+    import openai
+
+    min_seconds = 4
+    max_seconds = 10
+    # Wait 2^x * 1 second between each retry starting with
+    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+    async_retrying = AsyncRetrying(
+        reraise=True,
+        stop=stop_after_attempt(embeddings.max_retries),
+        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+        retry=(
+            retry_if_exception_type(openai.error.Timeout)
+            | retry_if_exception_type(openai.error.APIError)
+            | retry_if_exception_type(openai.error.APIConnectionError)
+            | retry_if_exception_type(openai.error.RateLimitError)
+            | retry_if_exception_type(openai.error.ServiceUnavailableError)
+        ),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
+
+    def wrap(func: Callable) -> Callable:
+        async def wrapped_f(*args: Any, **kwargs: Any) -> Any:
+            async for attempt in async_retrying:
+                # The attempt context manager is what reports a failure back to
+                # tenacity; without it the first exception escapes un-retried.
+                with attempt:
+                    return await func(*args, **kwargs)
+            raise AssertionError("this is unreachable")
+
+        return wrapped_f
+
+    return wrap
+
+
 def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
     """Use tenacity to retry the embedding call."""
     retry_decorator = _create_retry_decorator(embeddings)
@@ -64,6 +101,16 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
     return _embed_with_retry(**kwargs)
 
 
+async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
+    """Use tenacity to retry the async embedding call."""
+
+    @_async_retry_decorator(embeddings)
+    async def _async_embed_with_retry(**kwargs: Any) -> Any:
+        return await embeddings.client.acreate(**kwargs)
+
+    return await _async_embed_with_retry(**kwargs)
+
+
 class OpenAIEmbeddings(BaseModel, Embeddings):
     """Wrapper around OpenAI embedding models.
 
@@ -269,6 +316,72 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
 
         return embeddings
 
+    # please refer to
+    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
+    async def _aget_len_safe_embeddings(
+        self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
+    ) -> List[List[float]]:
+        """Embed texts of any length by chunking, then averaging chunk vectors."""
+        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
+        try:
+            import tiktoken
+        except ImportError:
+            raise ImportError(
+                "Could not import tiktoken python package. "
+                "This is needed in order to for OpenAIEmbeddings. "
+                "Please install it with `pip install tiktoken`."
+            )
+
+        tokens = []
+        indices = []
+        encoding = tiktoken.model.encoding_for_model(self.model)
+        for i, text in enumerate(texts):
+            if self.model.endswith("001"):
+                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+                # replace newlines, which can negatively affect performance.
+                text = text.replace("\n", " ")
+            token = encoding.encode(
+                text,
+                allowed_special=self.allowed_special,
+                disallowed_special=self.disallowed_special,
+            )
+            for j in range(0, len(token), self.embedding_ctx_length):
+                tokens += [token[j : j + self.embedding_ctx_length]]
+                indices += [i]
+
+        batched_embeddings = []
+        _chunk_size = chunk_size or self.chunk_size
+        for i in range(0, len(tokens), _chunk_size):
+            response = await async_embed_with_retry(
+                self,
+                input=tokens[i : i + _chunk_size],
+                **self._invocation_params,
+            )
+            batched_embeddings += [r["embedding"] for r in response["data"]]
+
+        results: List[List[List[float]]] = [[] for _ in range(len(texts))]
+        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
+        for i in range(len(indices)):
+            results[indices[i]].append(batched_embeddings[i])
+            num_tokens_in_batch[indices[i]].append(len(tokens[i]))
+
+        for i in range(len(texts)):
+            _result = results[i]
+            if len(_result) == 0:
+                # No chunks (empty text): fall back to embedding the empty string.
+                average = (
+                    await async_embed_with_retry(
+                        self,
+                        input="",
+                        **self._invocation_params,
+                    )
+                )["data"][0]["embedding"]
+            else:
+                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
+            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+
+        return embeddings
+
     def _embedding_func(self, text: str, *, engine: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint."""
         # handle large input text
@@ -287,6 +400,24 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                 "data"
             ][0]["embedding"]
 
+    async def _aembedding_func(self, text: str, *, engine: str) -> List[float]:
+        """Call out to OpenAI's embedding endpoint."""
+        # handle large input text
+        if len(text) > self.embedding_ctx_length:
+            return (await self._aget_len_safe_embeddings([text], engine=engine))[0]
+        else:
+            if self.model.endswith("001"):
+                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
+                # replace newlines, which can negatively affect performance.
+                text = text.replace("\n", " ")
+            return (
+                await async_embed_with_retry(
+                    self,
+                    input=[text],
+                    **self._invocation_params,
+                )
+            )["data"][0]["embedding"]
+
     def embed_documents(
         self, texts: List[str], chunk_size: Optional[int] = 0
     ) -> List[List[float]]:
@@ -304,6 +435,25 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         # than the maximum context and use length-safe embedding function.
         return self._get_len_safe_embeddings(texts, engine=self.deployment)
 
+    async def aembed_documents(
+        self, texts: List[str], chunk_size: Optional[int] = 0
+    ) -> List[List[float]]:
+        """Call out to OpenAI's embedding endpoint async for embedding search docs.
+
+        Args:
+            texts: The list of texts to embed.
+            chunk_size: The chunk size of embeddings. If None or 0, will use the
+                chunk size specified by the class.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        # NOTE: to keep things simple, we assume the list may contain texts longer
+        # than the maximum context and use length-safe embedding function.
+        return await self._aget_len_safe_embeddings(
+            texts, engine=self.deployment, chunk_size=chunk_size
+        )
+
     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
 
@@ -315,3 +465,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         """
         embedding = self._embedding_func(text, engine=self.deployment)
         return embedding
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Call out to OpenAI's embedding endpoint async for embedding query text.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embedding for the text.
+        """
+        embedding = await self._aembedding_func(text, engine=self.deployment)
+        return embedding
diff --git a/tests/integration_tests/embeddings/test_openai.py b/tests/integration_tests/embeddings/test_openai.py
index 1dba7553..6033d3f1 100644
--- a/tests/integration_tests/embeddings/test_openai.py
+++ b/tests/integration_tests/embeddings/test_openai.py
@@ -1,6 +1,7 @@
 """Test openai embeddings."""
 import numpy as np
 import openai
+import pytest
 
 from langchain.embeddings.openai import OpenAIEmbeddings
 
@@ -26,6 +27,19 @@ def test_openai_embedding_documents_multiple() -> None:
     assert len(output[2]) == 1536
 
 
+@pytest.mark.asyncio
+async def test_openai_embedding_documents_async_multiple() -> None:
+    """Test openai embeddings."""
+    documents = ["foo bar", "bar foo", "foo"]
+    embedding = OpenAIEmbeddings(chunk_size=2)
+    embedding.embedding_ctx_length = 8191
+    output = await embedding.aembed_documents(documents)
+    assert len(output) == 3
+    assert len(output[0]) == 1536
+    assert len(output[1]) == 1536
+    assert len(output[2]) == 1536
+
+
 def test_openai_embedding_query() -> None:
     """Test openai embeddings."""
     document = "foo bar"
@@ -34,6 +48,15 @@ def test_openai_embedding_query() -> None:
     assert len(output) == 1536
 
 
+@pytest.mark.asyncio
+async def test_openai_embedding_async_query() -> None:
+    """Test openai embeddings."""
+    document = "foo bar"
+    embedding = OpenAIEmbeddings()
+    output = await embedding.aembed_query(document)
+    assert len(output) == 1536
+
+
 def test_openai_embedding_with_empty_string() -> None:
     """Test openai embeddings with empty string."""
     document = ["", "abc"]