From 37ed6f217776487b4417893d361f6c4ebe3e2e4d Mon Sep 17 00:00:00 2001
From: Rafal Wojdyla
Date: Sat, 29 Apr 2023 04:10:04 +0100
Subject: [PATCH] Handle length safe embedding only if needed (#3723)

Re: https://github.com/hwchase17/langchain/issues/3722

Copy-pasting context from the issue:

https://github.com/hwchase17/langchain/blob/1bf1c37c0cccb7c8c73d87ace27cf742f814dbe5/langchain/embeddings/openai.py#L210-L211

means that the length-safe embedding method is *always* used. The initial
implementation (https://github.com/hwchase17/langchain/pull/991) set
`embedding_ctx_length` to -1, so you had to opt in to the length-safe
method; https://github.com/hwchase17/langchain/pull/2330 changed that to
the max length of OpenAI embeddings v2, meaning the length-safe method is
now used at all times.

How about changing that if branch to use the length-safe method only when
needed, i.e. when the text is longer than the max context length?
---
 langchain/embeddings/openai.py | 19 ++++---------------
 1 file changed, 4 insertions(+), 15 deletions(-)

diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index e52b695a..7caa304f 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -207,7 +207,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     def _embedding_func(self, text: str, *, engine: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint."""
         # handle large input text
-        if self.embedding_ctx_length > 0:
+        if len(text) > self.embedding_ctx_length:
             return self._get_len_safe_embeddings([text], engine=engine)[0]
         else:
             # replace newlines, which can negatively affect performance.
@@ -229,20 +229,9 @@
         Returns:
             List of embeddings, one for each text.
         """
-        # handle batches of large input text
-        if self.embedding_ctx_length > 0:
-            return self._get_len_safe_embeddings(texts, engine=self.deployment)
-        else:
-            results = []
-            _chunk_size = chunk_size or self.chunk_size
-            for i in range(0, len(texts), _chunk_size):
-                response = embed_with_retry(
-                    self,
-                    input=texts[i : i + _chunk_size],
-                    engine=self.deployment,
-                )
-                results += [r["embedding"] for r in response["data"]]
-            return results
+        # NOTE: to keep things simple, we assume the list may contain texts longer
+        # than the maximum context and use length-safe embedding function.
+        return self._get_len_safe_embeddings(texts, engine=self.deployment)
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
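
To make the dispatch the patch introduces concrete, here is a minimal,
self-contained sketch of the new `_embedding_func` behavior. It is
illustrative only, not langchain's actual code: `EMBEDDING_CTX_LENGTH`,
`embed_short`, and `embed_len_safe` are hypothetical stand-ins for the
model's context limit, the plain single-request path, and the chunking
`_get_len_safe_embeddings` path.

```python
from typing import List

# Hypothetical stand-in for OpenAIEmbeddings.embedding_ctx_length
# (8191 is the documented context limit of text-embedding-ada-002).
EMBEDDING_CTX_LENGTH = 8191

def embed_short(text: str) -> List[float]:
    """Stand-in for the plain, single-request embedding call."""
    print(f"plain path ({len(text)} chars)")
    return [0.0]  # dummy vector

def embed_len_safe(text: str) -> List[float]:
    """Stand-in for the chunk-then-combine length-safe path."""
    print(f"length-safe path ({len(text)} chars)")
    return [0.0]  # dummy vector

def embedding_func(text: str) -> List[float]:
    # After the patch: take the length-safe path only when the input
    # might not fit in the model's context window.
    if len(text) > EMBEDDING_CTX_LENGTH:
        return embed_len_safe(text)
    # replace newlines, which can negatively affect performance
    text = text.replace("\n", " ")
    return embed_short(text)

embedding_func("short query")  # -> plain path
embedding_func("x" * 10_000)   # -> length-safe path
```

One design note: `len(text)` counts characters while `embedding_ctx_length`
is a token limit, so the comparison is a cheap heuristic rather than an
exact token count. For most text it errs toward the safe path, since
characters usually outnumber tokens; only inputs whose token count exceeds
their character count (e.g. emoji-heavy text) could slip past it, and the
length-safe path still performs exact token-based chunking once taken.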