@ -18,6 +18,7 @@ from typing import (
import numpy as np
from pydantic import BaseModel , Extra , root_validator
from tenacity import (
AsyncRetrying ,
before_sleep_log ,
retry ,
retry_if_exception_type ,
@ -53,6 +54,38 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
)
def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
    """Build a decorator that retries an async OpenAI call with backoff.

    Args:
        embeddings: Object whose ``max_retries`` attribute caps the number
            of attempts.

    Returns:
        A decorator for coroutine functions. The wrapped coroutine is
        re-awaited when it raises one of the OpenAI error types listed
        below; once the attempts are used up the last error is re-raised
        (``reraise=True``).
    """
    import functools

    import openai

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry, starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    async_retrying = AsyncRetrying(
        reraise=True,
        stop=stop_after_attempt(embeddings.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.error.Timeout)
            | retry_if_exception_type(openai.error.APIError)
            | retry_if_exception_type(openai.error.APIConnectionError)
            | retry_if_exception_type(openai.error.RateLimitError)
            | retry_if_exception_type(openai.error.ServiceUnavailableError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )

    def wrap(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapped_f(*args: Any, **kwargs: Any) -> Any:
            async for attempt in async_retrying:
                # Entering the attempt is what reports a failure back to
                # tenacity. Without it the first exception simply escapes
                # the loop and no retry ever takes place.
                with attempt:
                    return await func(*args, **kwargs)
            raise AssertionError(" this is unreachable ")

        return wrapped_f

    return wrap
def embed_with_retry ( embeddings : OpenAIEmbeddings , * * kwargs : Any ) - > Any :
""" Use tenacity to retry the embedding call. """
retry_decorator = _create_retry_decorator ( embeddings )
@ -64,6 +97,16 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
return _embed_with_retry ( * * kwargs )
async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
    """Await the async embedding call under the tenacity retry policy.

    Args:
        embeddings: Object providing ``client.acreate``; also handed to
            ``_async_retry_decorator`` to build the retry policy.
        **kwargs: Forwarded unchanged to ``embeddings.client.acreate``.

    Returns:
        Whatever ``embeddings.client.acreate`` returns.
    """
    with_retries = _async_retry_decorator(embeddings)

    async def _request(**request_kwargs: Any) -> Any:
        return await embeddings.client.acreate(**request_kwargs)

    guarded_request = with_retries(_request)
    return await guarded_request(**kwargs)
class OpenAIEmbeddings ( BaseModel , Embeddings ) :
""" Wrapper around OpenAI embedding models.
@ -269,6 +312,70 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
return embeddings
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
async def _aget_len_safe_embeddings(
    self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]:
    """Asynchronously embed texts of any length by splitting them into windows.

    Every text is tokenized with tiktoken and cut into consecutive windows
    of at most ``self.embedding_ctx_length`` tokens. The windows are sent to
    the embedding endpoint in batches; the window embeddings that belong to
    one text are then combined into an average weighted by window token
    count and divided by its norm.

    Args:
        texts: The texts to embed.
        engine: Part of the signature but not read by this method.
        chunk_size: Number of token windows per request. Falls back to
            ``self.chunk_size`` when ``None`` or ``0``.

    Returns:
        One embedding per input text, in input order.

    Raises:
        ImportError: If the ``tiktoken`` package cannot be imported.
    """
    try:
        import tiktoken
    except ImportError:
        raise ImportError(
            " Could not import tiktoken python package. "
            " This is needed in order to for OpenAIEmbeddings. "
            " Please install it with `pip install tiktoken`. "
        )

    encoder = tiktoken.model.encoding_for_model(self.model)

    # Flatten all texts into token windows, remembering which text owns each.
    windows = []
    owners = []
    for text_idx, raw_text in enumerate(texts):
        if self.model.endswith(" 001 "):
            # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
            # replace newlines, which can negatively affect performance.
            raw_text = raw_text.replace(" \n ", " ")
        token_ids = encoder.encode(
            raw_text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )
        for start in range(0, len(token_ids), self.embedding_ctx_length):
            windows.append(token_ids[start : start + self.embedding_ctx_length])
            owners.append(text_idx)

    # Request embeddings one batch of windows at a time, in order.
    batch_len = chunk_size or self.chunk_size
    window_vectors = []
    for start in range(0, len(windows), batch_len):
        reply = await async_embed_with_retry(
            self,
            input=windows[start : start + batch_len],
            **self._invocation_params,
        )
        window_vectors.extend([item[" embedding "] for item in reply[" data "]])

    # Regroup window embeddings (and their token counts) by owning text.
    vectors_by_text: List[List[List[float]]] = [[] for _ in texts]
    weights_by_text: List[List[int]] = [[] for _ in texts]
    for pos, owner in enumerate(owners):
        vectors_by_text[owner].append(window_vectors[pos])
        weights_by_text[owner].append(len(windows[pos]))

    embeddings: List[List[float]] = []
    for vectors, weights in zip(vectors_by_text, weights_by_text):
        if vectors:
            mean = np.average(vectors, axis=0, weights=weights)
        else:
            # The text produced no token windows; embed the placeholder
            # input below so the caller still gets a vector for it.
            fallback = await async_embed_with_retry(
                self,
                input=" ",
                **self._invocation_params,
            )
            mean = fallback[" data "][0][" embedding "]
        embeddings.append((mean / np.linalg.norm(mean)).tolist())

    return embeddings
def _embedding_func ( self , text : str , * , engine : str ) - > List [ float ] :
""" Call out to OpenAI ' s embedding endpoint. """
# handle large input text
@ -287,6 +394,24 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
" data "
] [ 0 ] [ " embedding " ]
async def _aembedding_func ( self , text : str , * , engine : str ) - > List [ float ] :
""" Call out to OpenAI ' s embedding endpoint. """
# handle large input text
if len ( text ) > self . embedding_ctx_length :
return ( await self . _aget_len_safe_embeddings ( [ text ] , engine = engine ) ) [ 0 ]
else :
if self . model . endswith ( " 001 " ) :
# See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
# replace newlines, which can negatively affect performance.
text = text . replace ( " \n " , " " )
return (
await async_embed_with_retry (
self ,
input = [ text ] ,
* * self . _invocation_params ,
)
) [ " data " ] [ 0 ] [ " embedding " ]
def embed_documents (
self , texts : List [ str ] , chunk_size : Optional [ int ] = 0
) - > List [ List [ float ] ] :
@ -304,6 +429,23 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
# than the maximum context and use length-safe embedding function.
return self . _get_len_safe_embeddings ( texts , engine = self . deployment )
async def aembed_documents(
    self, texts: List[str], chunk_size: Optional[int] = 0
) -> List[List[float]]:
    """Call out to OpenAI's embedding endpoint async for embedding search docs.

    Args:
        texts: The list of texts to embed.
        chunk_size: The chunk size of embeddings. If None or 0, will use the
            chunk size specified by the class.

    Returns:
        List of embeddings, one for each text.
    """
    # NOTE: to keep things simple, we assume the list may contain texts longer
    #       than the maximum context and use length-safe embedding function.
    # chunk_size is forwarded so a caller-supplied value is actually honoured
    # instead of being silently dropped.
    return await self._aget_len_safe_embeddings(
        texts, engine=self.deployment, chunk_size=chunk_size
    )
def embed_query ( self , text : str ) - > List [ float ] :
""" Call out to OpenAI ' s embedding endpoint for embedding query text.
@ -315,3 +457,15 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"""
embedding = self . _embedding_func ( text , engine = self . deployment )
return embedding
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a query text with OpenAI's embedding endpoint.

    Args:
        text: The text to embed.

    Returns:
        Embedding for the text.
    """
    return await self._aembedding_func(text, engine=self.deployment)