Fixed OpenAI embeddings to be safe by batching them based on a token-count calculation. (#991)

I modified the batch-calculation logic for embeddings to follow this cookbook:

https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
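
For context, the cookbook's recipe is: tokenize each input, split the token sequence into chunks that fit the model's context window, embed each chunk, and combine the chunk embeddings with a token-count-weighted average, re-normalized to unit length. A minimal sketch of that idea, where `embed_chunk` is a hypothetical stand-in for the API call and the encoding name and context length are assumptions:

```python
from typing import List

import numpy as np
import tiktoken

EMBEDDING_CTX_LENGTH = 8191  # assumed limit for text-embedding-ada-002
ENCODING_NAME = "cl100k_base"  # assumed tokenizer for that model


def embed_chunk(chunk: List[int]) -> np.ndarray:
    """Hypothetical stand-in for the OpenAI embedding API call."""
    rng = np.random.default_rng(len(chunk))  # deterministic dummy vector
    return rng.standard_normal(1536)


def len_safe_embedding(text: str) -> List[float]:
    encoding = tiktoken.get_encoding(ENCODING_NAME)
    tokens = encoding.encode(text.replace("\n", " "))
    # split into context-window-sized chunks of tokens
    chunks = [
        tokens[i : i + EMBEDDING_CTX_LENGTH]
        for i in range(0, len(tokens), EMBEDDING_CTX_LENGTH)
    ]
    chunk_embeddings = [embed_chunk(c) for c in chunks]
    chunk_lens = [len(c) for c in chunks]
    # token-count-weighted average, re-normalized to unit length
    average = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
    return (average / np.linalg.norm(average)).tolist()
```

The weighting means a short tail chunk pulls the combined vector far less than a full context-window-sized body chunk.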

@@ -1,6 +1,7 @@
 """Wrapper around OpenAI embedding models."""
 from typing import Any, Dict, List, Optional
 
+import numpy as np
 from pydantic import BaseModel, Extra, root_validator
 
 from langchain.embeddings.base import Embeddings
@@ -24,6 +25,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     client: Any  #: :meta private:
     document_model_name: str = "text-embedding-ada-002"
     query_model_name: str = "text-embedding-ada-002"
+    embedding_ctx_length: int = -1
     openai_api_key: Optional[str] = None
 
     class Config:
@ -69,11 +71,62 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
) )
return values return values
# please refer to
# https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
def _get_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: int = 1000
) -> List[List[float]]:
embeddings: List[List[float]] = [[] for i in range(len(texts))]
try:
import tiktoken
tokens = []
indices = []
encoding = tiktoken.model.encoding_for_model(self.document_model_name)
for i, text in enumerate(texts):
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
token = encoding.encode(text)
for j in range(0, len(token), self.embedding_ctx_length):
tokens += [token[j : j + self.embedding_ctx_length]]
indices += [i]
batched_embeddings = []
for i in range(0, len(tokens), chunk_size):
response = self.client.create(
input=tokens[i : i + chunk_size], engine=self.document_model_name
)
batched_embeddings += [r["embedding"] for r in response["data"]]
results: List[List[List[float]]] = [[] for i in range(len(texts))]
lens: List[List[int]] = [[] for i in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
lens[indices[i]].append(len(batched_embeddings[i]))
for i in range(len(texts)):
average = np.average(results[i], axis=0, weights=lens[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
return embeddings
except ImportError:
raise ValueError(
"Could not import tiktoken python package. "
"This is needed in order to for OpenAIEmbeddings. "
"Please it install it with `pip install tiktoken`."
)
def _embedding_func(self, text: str, *, engine: str) -> List[float]: def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint.""" """Call out to OpenAI's embedding endpoint."""
# replace newlines, which can negatively affect performance. # replace newlines, which can negatively affect performance.
text = text.replace("\n", " ") if self.embedding_ctx_length > 0:
return self.client.create(input=[text], engine=engine)["data"][0]["embedding"] return self._get_len_safe_embeddings([text], engine=engine)[0]
else:
text = text.replace("\n", " ")
return self.client.create(input=[text], engine=engine)["data"][0][
"embedding"
]
def embed_documents( def embed_documents(
self, texts: List[str], chunk_size: int = 1000 self, texts: List[str], chunk_size: int = 1000
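
With the new `embedding_ctx_length` field left at its default of -1, the old single-call behavior is preserved; setting it to a positive value opts into the length-safe path. A hypothetical usage sketch (the 8191-token limit and the token estimate are assumptions):

```python
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()  # assumes OPENAI_API_KEY is set in the environment

# roughly 10k tokens, beyond text-embedding-ada-002's assumed 8191-token window
long_text = "hello world " * 5000

embeddings.embedding_ctx_length = 8191  # opt in to length-safe chunking
vector = embeddings.embed_query(long_text)
assert len(vector) == 1536  # one unit-normalized vector despite the long input
```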
@@ -89,13 +142,16 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             List of embeddings, one for each text.
         """
         # handle large batches of texts
-        results = []
-        for i in range(0, len(texts), chunk_size):
-            response = self.client.create(
-                input=texts[i : i + chunk_size], engine=self.document_model_name
-            )
-            results += [r["embedding"] for r in response["data"]]
-        return results
+        if self.embedding_ctx_length > 0:
+            return self._get_len_safe_embeddings(
+                texts, engine=self.document_model_name, chunk_size=chunk_size
+            )
+        else:
+            responses = [
+                self._embedding_func(text, engine=self.document_model_name)
+                for text in texts
+            ]
+            return responses
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to OpenAI's embedding endpoint for embedding query text.

@@ -15,6 +15,7 @@ def test_openai_embedding_documents_multiple() -> None:
     """Test openai embeddings."""
     documents = ["foo bar", "bar foo", "foo"]
     embedding = OpenAIEmbeddings()
+    embedding.embedding_ctx_length = 8191
    output = embedding.embed_documents(documents, chunk_size=2)
     assert len(output) == 3
     assert len(output[0]) == 1536
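
A complementary test one might add (hypothetical, not part of this commit) would exercise an input longer than the context window; it assumes the same imports as the existing test module:

```python
def test_openai_embedding_long_document() -> None:
    """Hypothetical test: a text longer than the context window still embeds."""
    document = "foo bar " * 5000  # roughly 10k tokens, above the assumed 8191 limit
    embedding = OpenAIEmbeddings()
    embedding.embedding_ctx_length = 8191
    output = embedding.embed_documents([document])
    assert len(output) == 1
    assert len(output[0]) == 1536
```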
