From 039b672f461fc2c3d64c94062e3355d459c59318 Mon Sep 17 00:00:00 2001
From: Rafal Wojdyla
Date: Mon, 1 May 2023 23:47:38 +0100
Subject: [PATCH] Fixup OpenAI Embeddings - fix the weighted mean (#3778)

Re: https://github.com/hwchase17/langchain/issues/3777

Copy-pasting from the issue:

While working on https://github.com/hwchase17/langchain/issues/3722 I noticed what appears to be a bug in the current implementation of the OpenAI length-safe embeddings in `_get_len_safe_embeddings`, which before https://github.com/hwchase17/langchain/issues/3722 was actually the **default implementation** regardless of the length of the context (via https://github.com/hwchase17/langchain/pull/2330).

It appears the weights used are constant, equal to the length of the embedding vector (1536), and NOT the number of tokens in the batch, as in the reference implementation at https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb. Because every weight is identical, the "weighted" average degenerates into a plain unweighted mean.
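To make the degeneracy concrete, here is a minimal sketch (the chunk embeddings and token counts below are made up for illustration, not values from the real model) of how constant weights collapse `np.average` into a plain mean:

```py
import numpy as np

# Hypothetical chunk embeddings for one long text split into two chunks:
# an 8191-token chunk and a 100-token chunk (values are made up).
chunk_embeddings = [np.ones(1536), np.zeros(1536)]

# Buggy weights: len(embedding) is 1536 for every chunk, so all weights
# are equal and the "weighted" mean degenerates into a plain mean.
buggy = np.average(chunk_embeddings, axis=0, weights=[1536, 1536])

# Correct weights: the token count of each chunk, as in the OpenAI
# cookbook reference implementation.
fixed = np.average(chunk_embeddings, axis=0, weights=[8191, 100])

print(buggy[0])  # 0.5
print(fixed[0])  # ~0.988, dominated by the much larger first chunk
```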
Here's some debug info: [debugger screenshot omitted]
We can also validate this against the reference implementation:
Reference implementation, copy-pasted from https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb:

```py
import openai
import tiktoken
import numpy as np
from itertools import islice
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type

EMBEDDING_MODEL = 'text-embedding-ada-002'
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = 'cl100k_base'

# let's make sure to not retry on an invalid request, because that is what we want to demonstrate
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6),
       retry=retry_if_not_exception_type(openai.InvalidRequestError))
def get_embedding(text_or_tokens, model=EMBEDDING_MODEL):
    return openai.Embedding.create(input=text_or_tokens, model=model)["data"][0]["embedding"]


def batched(iterable, n):
    """Batch data into tuples of length n. The last batch may be shorter."""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch


def chunked_tokens(text, encoding_name, chunk_length):
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator


def reference_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH,
                                 encoding_name=EMBEDDING_ENCODING, average=True):
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):
        chunk_embeddings.append(get_embedding(chunk, model=model))
        chunk_lens.append(len(chunk))
    if average:
        # weight each chunk's embedding by its token count, then renormalize
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)  # normalizes length to 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings
```
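Note that the reference weights each chunk's embedding by `chunk_lens`, the number of tokens in that chunk, before renormalizing to unit length; the LangChain implementation instead weighted by the length of the embedding vector itself.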
```py
long_text = 'foo bar' * 5000

reference_safe_get_embedding(long_text, average=True)[:10]
# Here are the first 10 floats from the reference embeddings:
[0.004407593824276758, 0.0017611146161865465, -0.019824815970984996, -0.02177626039794025, -0.012060967454897886, 0.0017955296329155309, -0.015609168983609643, -0.012059823076681351, -0.016990468527792825, -0.004970484452089445]

# and now the langchain implementation
from langchain.embeddings.openai import OpenAIEmbeddings
OpenAIEmbeddings().embed_query(long_text)[:10]
[0.003791506184693747, 0.0025310066579390025, -0.019282322699514628, -0.021492679249899803, -0.012598522213242891, 0.0022181168611315662, -0.015858940621301307, -0.011754004130791204, -0.016402944319627515, -0.004125287485127554]

# clearly they are different ^
```
---
 langchain/embeddings/openai.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index 7caa304fb6..c10ffc6040 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -150,7 +150,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     def _get_len_safe_embeddings(
         self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
     ) -> List[List[float]]:
-        embeddings: List[List[float]] = [[] for i in range(len(texts))]
+        embeddings: List[List[float]] = [[] for _ in range(len(texts))]
 
         try:
             import tiktoken
@@ -180,10 +180,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
             batched_embeddings += [r["embedding"] for r in response["data"]]
 
         results: List[List[List[float]]] = [[] for _ in range(len(texts))]
-        lens: List[List[int]] = [[] for _ in range(len(texts))]
+        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
         for i in range(len(indices)):
             results[indices[i]].append(batched_embeddings[i])
-            lens[indices[i]].append(len(batched_embeddings[i]))
+            num_tokens_in_batch[indices[i]].append(len(tokens[i]))
 
         for i in range(len(texts)):
             _result = results[i]
@@ -192,7 +192,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
                     "data"
                 ][0]["embedding"]
             else:
-                average = np.average(_result, axis=0, weights=lens[i])
+                average = np.average(
+                    _result, axis=0, weights=num_tokens_in_batch[i]
+                )
             embeddings[i] = (average / np.linalg.norm(average)).tolist()
 
         return embeddings