From e08961ab2566ffbf3e9c5681464119d5a4144138 Mon Sep 17 00:00:00 2001 From: Hasegawa Yuya <52068175+Hase-U@users.noreply.github.com> Date: Thu, 16 Feb 2023 16:02:32 +0900 Subject: [PATCH] Fixed openai embeddings to be safe by batching them based on token size calculation. (#991) I modified the logic of the batch calculation for embedding according to this cookbook https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb --- langchain/embeddings/openai.py | 72 ++++++++++++++++--- .../embeddings/test_openai.py | 1 + 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index eb48c622..fbf6a1bc 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -1,6 +1,7 @@ """Wrapper around OpenAI embedding models.""" from typing import Any, Dict, List, Optional +import numpy as np from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings @@ -24,6 +25,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): client: Any #: :meta private: document_model_name: str = "text-embedding-ada-002" query_model_name: str = "text-embedding-ada-002" + embedding_ctx_length: int = -1 openai_api_key: Optional[str] = None class Config: @@ -69,11 +71,62 @@ class OpenAIEmbeddings(BaseModel, Embeddings): ) return values + # please refer to + # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb + def _get_len_safe_embeddings( + self, texts: List[str], *, engine: str, chunk_size: int = 1000 + ) -> List[List[float]]: + embeddings: List[List[float]] = [[] for i in range(len(texts))] + try: + import tiktoken + + tokens = [] + indices = [] + encoding = tiktoken.model.encoding_for_model(self.document_model_name) + for i, text in enumerate(texts): + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + token = encoding.encode(text) + for j in range(0, len(token), self.embedding_ctx_length): + tokens += [token[j : j + self.embedding_ctx_length]] + indices += [i] + + batched_embeddings = [] + for i in range(0, len(tokens), chunk_size): + response = self.client.create( + input=tokens[i : i + chunk_size], engine=self.document_model_name + ) + batched_embeddings += [r["embedding"] for r in response["data"]] + + results: List[List[List[float]]] = [[] for i in range(len(texts))] + lens: List[List[int]] = [[] for i in range(len(texts))] + for i in range(len(indices)): + results[indices[i]].append(batched_embeddings[i]) + lens[indices[i]].append(len(batched_embeddings[i])) + + for i in range(len(texts)): + average = np.average(results[i], axis=0, weights=lens[i]) + embeddings[i] = (average / np.linalg.norm(average)).tolist() + + return embeddings + + except ImportError: + raise ValueError( + "Could not import tiktoken python package. " + "This is needed in order to for OpenAIEmbeddings. " + "Please it install it with `pip install tiktoken`." + ) + def _embedding_func(self, text: str, *, engine: str) -> List[float]: """Call out to OpenAI's embedding endpoint.""" # replace newlines, which can negatively affect performance. - text = text.replace("\n", " ") - return self.client.create(input=[text], engine=engine)["data"][0]["embedding"] + if self.embedding_ctx_length > 0: + return self._get_len_safe_embeddings([text], engine=engine)[0] + else: + text = text.replace("\n", " ") + return self.client.create(input=[text], engine=engine)["data"][0][ + "embedding" + ] def embed_documents( self, texts: List[str], chunk_size: int = 1000 @@ -89,13 +142,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ # handle large batches of texts - results = [] - for i in range(0, len(texts), chunk_size): - response = self.client.create( - input=texts[i : i + chunk_size], engine=self.document_model_name + if self.embedding_ctx_length > 0: + return self._get_len_safe_embeddings( + texts, engine=self.document_model_name, chunk_size=chunk_size ) - results += [r["embedding"] for r in response["data"]] - return results + else: + responses = [ + self._embedding_func(text, engine=self.document_model_name) + for text in texts + ] + return responses def embed_query(self, text: str) -> List[float]: """Call out to OpenAI's embedding endpoint for embedding query text. diff --git a/tests/integration_tests/embeddings/test_openai.py b/tests/integration_tests/embeddings/test_openai.py index 13beee61..3fdb1e53 100644 --- a/tests/integration_tests/embeddings/test_openai.py +++ b/tests/integration_tests/embeddings/test_openai.py @@ -15,6 +15,7 @@ def test_openai_embedding_documents_multiple() -> None: """Test openai embeddings.""" documents = ["foo bar", "bar foo", "foo"] embedding = OpenAIEmbeddings() + embedding.embedding_ctx_length = 8191 output = embedding.embed_documents(documents, chunk_size=2) assert len(output) == 3 assert len(output[0]) == 1536