OpenAIEmbeddings: Add an optional parameter to skip empty embeddings (#10196)

## Description

### Issue
This pull request addresses a lingering issue from PR #7070, which
attempted to fix the problem of empty embeddings returned by the
`OpenAIEmbeddings` class. Although that PR introduced a retry mechanism
for embedding requests, it did not fully resolve the issue: empty
embeddings still occasionally persisted.

### Problem
In certain use cases, the OpenAI API returns empty embeddings. These can
often be skipped or removed without affecting the functionality of the
application, but retries do not always resolve them, and their presence
can break applications that rely on the `OpenAIEmbeddings` class.
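
For context, an "empty" embedding here surfaces as a length-1 vector in the API response (see the StackOverflow link referenced in the diff below). A minimal sketch of how such a response might look and how it can be detected; the payload is a hypothetical, trimmed-down stand-in for the real API response:

```python
# Hypothetical, trimmed-down OpenAI embeddings response payload (assumption:
# real responses carry more fields, but "data" -> "embedding" is what matters here).
response = {
    "data": [
        {"embedding": [0.12, -0.04, 0.33]},  # a normal embedding vector
        {"embedding": [0.0]},                # an "empty" (length-1) embedding
    ]
}

# The same length check this PR uses to flag empty embeddings
has_empty = any(len(d["embedding"]) == 1 for d in response["data"])
print(has_empty)  # True
```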

### Solution
To handle empty embeddings more robustly, this PR introduces an optional
parameter, `skip_empty`, on the `OpenAIEmbeddings` class. When set to
`True`, empty embeddings are skipped automatically rather than raising
an error, so they no longer disrupt the processing flow. The parameter
defaults to `False`, preserving the existing behavior, and developers
can toggle it as needed (see the usage example below).

## Changes Made
- Added an optional parameter, `skip_empty`, to the `OpenAIEmbeddings`
class.
- When `skip_empty` is set to `True`, empty embeddings are automatically
skipped without causing errors or disruptions.

### Example Usage
```python
from langchain.embeddings import OpenAIEmbeddings

# Initialize the OpenAIEmbeddings class with skip_empty=True
embeddings = OpenAIEmbeddings(openai_api_key="your_api_key", skip_empty=True)

# docs is a list containing the already-split text chunks
docs = ["first chunk of text", "second chunk of text"]

# Request embeddings; empty embeddings are automatically skipped
results = embeddings.embed_documents(docs)

# Process results without interruption from empty embeddings
```
### Diff

```diff
@@ -87,8 +87,8 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
 # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
-def _check_response(response: dict) -> dict:
-    if any(len(d["embedding"]) == 1 for d in response["data"]):
+def _check_response(response: dict, skip_empty: bool = False) -> dict:
+    if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
         import openai

         raise openai.error.APIError("OpenAI API returned an empty embedding")
@@ -102,7 +102,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
     @retry_decorator
     def _embed_with_retry(**kwargs: Any) -> Any:
         response = embeddings.client.create(**kwargs)
-        return _check_response(response)
+        return _check_response(response, skip_empty=embeddings.skip_empty)

     return _embed_with_retry(**kwargs)
@@ -113,7 +113,7 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
     @_async_retry_decorator(embeddings)
     async def _async_embed_with_retry(**kwargs: Any) -> Any:
         response = await embeddings.client.acreate(**kwargs)
-        return _check_response(response)
+        return _check_response(response, skip_empty=embeddings.skip_empty)

     return await _async_embed_with_retry(**kwargs)
@@ -196,6 +196,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """Whether to show a progress bar when embedding."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
+    skip_empty: bool = False
+    """Whether to skip empty strings when embedding or raise an error.
+    Defaults to not skipping."""

     class Config:
         """Configuration for this pydantic object."""
@@ -371,6 +374,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         results: List[List[List[float]]] = [[] for _ in range(len(texts))]
         num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
         for i in range(len(indices)):
+            if self.skip_empty and len(batched_embeddings[i]) == 1:
+                continue
             results[indices[i]].append(batched_embeddings[i])
             num_tokens_in_batch[indices[i]].append(len(tokens[i]))
```
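
To illustrate the effect of the new flag on `_check_response` (a private helper, imported here purely for demonstration; the payload is the same kind of hypothetical, trimmed-down stand-in used above):

```python
# _check_response is an internal helper in langchain.embeddings.openai,
# imported here only to demonstrate the new skip_empty flag.
from langchain.embeddings.openai import _check_response

# Hypothetical response with one normal and one "empty" (length-1) embedding.
response = {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.0]}]}

# Default behavior is unchanged: an empty embedding raises openai.error.APIError.
# _check_response(response)

# With skip_empty=True the check is bypassed and the response passes through;
# the length-1 entry is then dropped by the new `continue` in the batching loop.
checked = _check_response(response, skip_empty=True)
```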