Handle length safe embedding only if needed (#3723)

Re: https://github.com/hwchase17/langchain/issues/3722

Copy pasting context from the issue:


1bf1c37c0c/langchain/embeddings/openai.py (L210-L211)

Means that the length safe embedding method is "always" used, initial
implementation https://github.com/hwchase17/langchain/pull/991 has the
`embedding_ctx_length` set to -1 (meaning you had to opt-in for the
length safe method), https://github.com/hwchase17/langchain/pull/2330
changed that to max length of OpenAI embeddings v2, meaning the length
safe method is used at all times.

How about changing that if branch to use length safe method only when
needed, meaning when the text is longer than the max context length?
fix_agent_callbacks
Rafal Wojdyla 1 year ago committed by GitHub
parent 40f6e60e68
commit 37ed6f2177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -207,7 +207,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint."""
# handle large input text
if self.embedding_ctx_length > 0:
if len(text) > self.embedding_ctx_length:
return self._get_len_safe_embeddings([text], engine=engine)[0]
else:
# replace newlines, which can negatively affect performance.
@ -229,20 +229,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
# handle batches of large input text
if self.embedding_ctx_length > 0:
return self._get_len_safe_embeddings(texts, engine=self.deployment)
else:
results = []
_chunk_size = chunk_size or self.chunk_size
for i in range(0, len(texts), _chunk_size):
response = embed_with_retry(
self,
input=texts[i : i + _chunk_size],
engine=self.deployment,
)
results += [r["embedding"] for r in response["data"]]
return results
# NOTE: to keep things simple, we assume the list may contain texts longer
# than the maximum context and use length-safe embedding function.
return self._get_len_safe_embeddings(texts, engine=self.deployment)
def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.

Loading…
Cancel
Save