From 28bef6f87dd2ba9ab3b1ea1752aadb28d4c80b14 Mon Sep 17 00:00:00 2001 From: vr140 Date: Mon, 10 Apr 2023 20:53:56 -0700 Subject: [PATCH] Clean up OpenAI Embeddings to fix method name and comments (#2687) **Problem:** OpenAI Embeddings has a few minor issues: method name and comment for _completion_with_retry seems to be a copypasta error and a few comments around usage of embedding_ctx_length seem to be incorrect. **Solution:** Clean up issues. --------- Co-authored-by: Vijay Rajaram --- langchain/embeddings/openai.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 09daf764..82d248bb 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -43,14 +43,14 @@ def _create_retry_decorator(embeddings: OpenAIEmbeddings) -> Callable[[Any], Any def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: - """Use tenacity to retry the completion call.""" + """Use tenacity to retry the embedding call.""" retry_decorator = _create_retry_decorator(embeddings) @retry_decorator - def _completion_with_retry(**kwargs: Any) -> Any: + def _embed_with_retry(**kwargs: Any) -> Any: return embeddings.client.create(**kwargs) - return _completion_with_retry(**kwargs) + return _embed_with_retry(**kwargs) class OpenAIEmbeddings(BaseModel, Embeddings): @@ -231,10 +231,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): def _embedding_func(self, text: str, *, engine: str) -> List[float]: """Call out to OpenAI's embedding endpoint.""" - # replace newlines, which can negatively affect performance. + # handle large input text if self.embedding_ctx_length > 0: return self._get_len_safe_embeddings([text], engine=engine)[0] else: + # replace newlines, which can negatively affect performance. text = text.replace("\n", " ") return embed_with_retry(self, input=[text], engine=engine)["data"][0][ "embedding" @@ -253,7 +254,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ - # handle large batches of texts + # handle batches of large input text if self.embedding_ctx_length > 0: return self._get_len_safe_embeddings(texts, engine=self.document_model_name) else: @@ -275,7 +276,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): text: The text to embed. Returns: - Embeddings for the text. + Embedding for the text. """ embedding = self._embedding_func(text, engine=self.query_model_name) return embedding