Clean up OpenAI Embeddings to fix method name and comments (#2687)

**Problem:**

OpenAI Embeddings has a few minor issues: method name and comment for
_completion_with_retry seems to be a copypasta error and a few comments
around usage of embedding_ctx_length seem to be incorrect.

**Solution:**

Clean up issues.

---------

Co-authored-by: Vijay Rajaram <vrajaram3@gatech.edu>
fix_agent_callbacks
vr140 1 year ago committed by GitHub
parent ad3c5dd186
commit 28bef6f87d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,14 +43,14 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any
def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
"""Use tenacity to retry the embedding call."""
retry_decorator = _create_retry_decorator(embeddings)
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
def _embed_with_retry(**kwargs: Any) -> Any:
return embeddings.client.create(**kwargs)
return _completion_with_retry(**kwargs)
return _embed_with_retry(**kwargs)
class OpenAIEmbeddings(BaseModel, Embeddings):
@ -231,10 +231,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# replace newlines, which can negatively affect performance.
# handle large input text
if self.embedding_ctx_length > 0:
return self._get_len_safe_embeddings([text], engine=engine)[0]
else:
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return embed_with_retry(self, input=[text], engine=engine)["data"][0][
"embedding"
@ -253,7 +254,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
# handle large batches of texts
# handle batches of large input text
if self.embedding_ctx_length > 0:
return self._get_len_safe_embeddings(texts, engine=self.document_model_name)
else:
@ -275,7 +276,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
text: The text to embed.
Returns:
Embeddings for the text.
Embedding for the text.
"""
embedding = self._embedding_func(text, engine=self.query_model_name)
return embedding

Loading…
Cancel
Save