mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
OpenAIEmbeddings: Add optional an optional parameter to skip empty embeddings (#10196)
## Description ### Issue This pull request addresses a lingering issue identified in PR #7070. In that previous pull request, an attempt was made to address the problem of empty embeddings when using the `OpenAIEmbeddings` class. While PR #7070 introduced a mechanism to retry requests for embeddings, it didn't fully resolve the issue as empty embeddings still occasionally persisted. ### Problem In certain specific use cases, empty embeddings can be encountered when requesting data from the OpenAI API. In some cases, these empty embeddings can be skipped or removed without affecting the functionality of the application. However, they might not always be resolved through retries, and their presence can adversely affect the functionality of applications relying on the `OpenAIEmbeddings` class. ### Solution To provide a more robust solution for handling empty embeddings, we propose the introduction of an optional parameter, `skip_empty`, in the `OpenAIEmbeddings` class. When set to `True`, this parameter will enable the behavior of automatically skipping empty embeddings, ensuring that problematic empty embeddings do not disrupt the processing flow. The developer will be able to optionally toggle this behavior if needed without disrupting the application flow. ## Changes Made - Added an optional parameter, `skip_empty`, to the `OpenAIEmbeddings` class. - When `skip_empty` is set to `True`, empty embeddings are automatically skipped without causing errors or disruptions. ### Example Usage ```python from openai.embeddings import OpenAIEmbeddings # Initialize the OpenAIEmbeddings class with skip_empty=True embeddings = OpenAIEmbeddings(api_key="your_api_key", skip_empty=True) # Request embeddings, empty embeddings are automatically skipped. docs is a variable containing the already splitted text. results = embeddings.embed_documents(docs) # Process results without interruption from empty embeddings ```
This commit is contained in:
parent
8998060d85
commit
5dbae94e04
@ -87,8 +87,8 @@ def _async_retry_decorator(embeddings: OpenAIEmbeddings) -> Any:
|
||||
|
||||
|
||||
# https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
|
||||
def _check_response(response: dict) -> dict:
|
||||
if any(len(d["embedding"]) == 1 for d in response["data"]):
|
||||
def _check_response(response: dict, skip_empty: bool = False) -> dict:
|
||||
if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
|
||||
import openai
|
||||
|
||||
raise openai.error.APIError("OpenAI API returned an empty embedding")
|
||||
@ -102,7 +102,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
|
||||
@retry_decorator
|
||||
def _embed_with_retry(**kwargs: Any) -> Any:
|
||||
response = embeddings.client.create(**kwargs)
|
||||
return _check_response(response)
|
||||
return _check_response(response, skip_empty=embeddings.skip_empty)
|
||||
|
||||
return _embed_with_retry(**kwargs)
|
||||
|
||||
@ -113,7 +113,7 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) ->
|
||||
@_async_retry_decorator(embeddings)
|
||||
async def _async_embed_with_retry(**kwargs: Any) -> Any:
|
||||
response = await embeddings.client.acreate(**kwargs)
|
||||
return _check_response(response)
|
||||
return _check_response(response, skip_empty=embeddings.skip_empty)
|
||||
|
||||
return await _async_embed_with_retry(**kwargs)
|
||||
|
||||
@ -196,6 +196,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
"""Whether to show a progress bar when embedding."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
skip_empty: bool = False
|
||||
"""Whether to skip empty strings when embedding or raise an error.
|
||||
Defaults to not skipping."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -371,6 +374,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
||||
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
if self.skip_empty and len(batched_embeddings[i]) == 1:
|
||||
continue
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user