Handle length safe embedding only if needed (#3723)

Re: https://github.com/hwchase17/langchain/issues/3722

Copy-pasting context from the issue:


1bf1c37c0c/langchain/embeddings/openai.py (L210-L211)

This means the length safe embedding method is "always" used. The initial
implementation (https://github.com/hwchase17/langchain/pull/991) set
`embedding_ctx_length` to -1, so you had to opt in to the length safe
method; https://github.com/hwchase17/langchain/pull/2330 changed the
default to the maximum context length of OpenAI embeddings v2, which
means the length safe method is now used at all times.

How about changing that if branch to use the length safe method only when
needed, i.e. when the text is longer than the maximum context length?
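
In code terms the proposal amounts to something like the following (a minimal sketch, not the library code; `embed_short` and `embed_len_safe` are hypothetical stand-ins for the direct and length-safe embedding paths):

```python
from typing import Callable, List

def embed(
    text: str,
    embedding_ctx_length: int,
    embed_short: Callable[[str], List[float]],
    embed_len_safe: Callable[[str], List[float]],
) -> List[float]:
    # Fall back to the (slower) length-safe path only when the input
    # actually exceeds the maximum context length.
    if len(text) > embedding_ctx_length:
        return embed_len_safe(text)
    # Short inputs take the direct path; newlines are replaced because
    # they can negatively affect embedding performance.
    return embed_short(text.replace("\n", " "))
```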

Rafal Wojdyla, 2023-04-29 04:10:04 +01:00 (committed by GitHub)
parent 40f6e60e68
commit 37ed6f2177

langchain/embeddings/openai.py

@@ -207,7 +207,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     def _embedding_func(self, text: str, *, engine: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint."""
         # handle large input text
-        if self.embedding_ctx_length > 0:
+        if len(text) > self.embedding_ctx_length:
             return self._get_len_safe_embeddings([text], engine=engine)[0]
         else:
             # replace newlines, which can negatively affect performance.
@@ -229,20 +229,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        # handle batches of large input text
-        if self.embedding_ctx_length > 0:
-            return self._get_len_safe_embeddings(texts, engine=self.deployment)
-        else:
-            results = []
-            _chunk_size = chunk_size or self.chunk_size
-            for i in range(0, len(texts), _chunk_size):
-                response = embed_with_retry(
-                    self,
-                    input=texts[i : i + _chunk_size],
-                    engine=self.deployment,
-                )
-                results += [r["embedding"] for r in response["data"]]
-            return results
+        # NOTE: to keep things simple, we assume the list may contain texts longer
+        #       than the maximum context and use length-safe embedding function.
+        return self._get_len_safe_embeddings(texts, engine=self.deployment)

     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.