diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index e52b695a..7caa304f 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -207,7 +207,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     def _embedding_func(self, text: str, *, engine: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint."""
         # handle large input text
-        if self.embedding_ctx_length > 0:
+        if len(text) > self.embedding_ctx_length:
             return self._get_len_safe_embeddings([text], engine=engine)[0]
         else:
             # replace newlines, which can negatively affect performance.
@@ -229,20 +229,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         Returns:
             List of embeddings, one for each text.
         """
-        # handle batches of large input text
-        if self.embedding_ctx_length > 0:
-            return self._get_len_safe_embeddings(texts, engine=self.deployment)
-        else:
-            results = []
-            _chunk_size = chunk_size or self.chunk_size
-            for i in range(0, len(texts), _chunk_size):
-                response = embed_with_retry(
-                    self,
-                    input=texts[i : i + _chunk_size],
-                    engine=self.deployment,
-                )
-                results += [r["embedding"] for r in response["data"]]
-            return results
+        # NOTE: to keep things simple, we assume the list may contain texts longer
+        # than the maximum context and use the length-safe embedding function.
+        return self._get_len_safe_embeddings(texts, engine=self.deployment)
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.
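
For reviewers, here is a rough sketch of the chunk-and-average strategy that `_get_len_safe_embeddings` implements, to make clear what `embed_documents` now delegates to unconditionally. This is a simplified illustration, assuming `tiktoken` and `numpy` are available; `embed_long_text` and `embed_batch` are hypothetical names for this sketch and are not identifiers from the patch:

```python
# Illustrative sketch only: token-aware chunking plus weighted averaging,
# in the spirit of _get_len_safe_embeddings. embed_long_text / embed_batch
# are hypothetical helpers, not part of this patch.
from typing import Callable, List

import numpy as np
import tiktoken


def embed_long_text(
    text: str,
    embed_batch: Callable[[List[str]], List[List[float]]],
    ctx_length: int = 8191,
    encoding_name: str = "cl100k_base",
) -> List[float]:
    """Embed a text of arbitrary length by splitting its token stream
    into context-sized windows and averaging the window embeddings."""
    enc = tiktoken.get_encoding(encoding_name)
    tokens = enc.encode(text)

    if not tokens:
        # Degenerate case: nothing to chunk, embed the raw text directly.
        return embed_batch([text])[0]

    chunks: List[str] = []
    weights: List[int] = []
    for i in range(0, len(tokens), ctx_length):
        window = tokens[i : i + ctx_length]
        chunks.append(enc.decode(window))  # each chunk fits the model context
        weights.append(len(window))        # weight by the window's token count

    # One batched call embeds all windows; a token-weighted average then
    # collapses them into a single vector, re-normalized to unit length.
    vectors = np.array(embed_batch(chunks))
    combined = np.average(vectors, axis=0, weights=weights)
    return (combined / np.linalg.norm(combined)).tolist()
```

The design choice here is that handling of long inputs lives in one place: a short text produces a single window, so the weighted average reduces to that window's own embedding, and `embed_documents` no longer needs a separate character-indexed batching loop.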