mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Handle length safe embedding only if needed (#3723)
Re: https://github.com/hwchase17/langchain/issues/3722
Copy pasting context from the issue:
1bf1c37c0c/langchain/embeddings/openai.py (L210-L211)
Means that the length safe embedding method is "always" used, initial
implementation https://github.com/hwchase17/langchain/pull/991 has the
`embedding_ctx_length` set to -1 (meaning you had to opt-in for the
length safe method), https://github.com/hwchase17/langchain/pull/2330
changed that to max length of OpenAI embeddings v2, meaning the length
safe method is used at all times.
How about changing that if branch to use length safe method only when
needed, meaning when the text is longer than the max context length?
This commit is contained in:
parent
40f6e60e68
commit
37ed6f2177
@ -207,7 +207,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
|
||||||
"""Call out to OpenAI's embedding endpoint."""
|
"""Call out to OpenAI's embedding endpoint."""
|
||||||
# handle large input text
|
# handle large input text
|
||||||
if self.embedding_ctx_length > 0:
|
if len(text) > self.embedding_ctx_length:
|
||||||
return self._get_len_safe_embeddings([text], engine=engine)[0]
|
return self._get_len_safe_embeddings([text], engine=engine)[0]
|
||||||
else:
|
else:
|
||||||
# replace newlines, which can negatively affect performance.
|
# replace newlines, which can negatively affect performance.
|
||||||
@ -229,20 +229,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
List of embeddings, one for each text.
|
List of embeddings, one for each text.
|
||||||
"""
|
"""
|
||||||
# handle batches of large input text
|
# NOTE: to keep things simple, we assume the list may contain texts longer
|
||||||
if self.embedding_ctx_length > 0:
|
# than the maximum context and use length-safe embedding function.
|
||||||
return self._get_len_safe_embeddings(texts, engine=self.deployment)
|
return self._get_len_safe_embeddings(texts, engine=self.deployment)
|
||||||
else:
|
|
||||||
results = []
|
|
||||||
_chunk_size = chunk_size or self.chunk_size
|
|
||||||
for i in range(0, len(texts), _chunk_size):
|
|
||||||
response = embed_with_retry(
|
|
||||||
self,
|
|
||||||
input=texts[i : i + _chunk_size],
|
|
||||||
engine=self.deployment,
|
|
||||||
)
|
|
||||||
results += [r["embedding"] for r in response["data"]]
|
|
||||||
return results
|
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
"""Call out to OpenAI's embedding endpoint for embedding query text.
|
||||||
|
Loading…
Reference in New Issue
Block a user