Fixup OpenAI Embeddings - fix the weighted mean (#3778)

Re: https://github.com/hwchase17/langchain/issues/3777

Copy pasting from the issue:

While working on https://github.com/hwchase17/langchain/issues/3722 I
have noticed that there might be a bug in the current implementation of
the OpenAI length safe embeddings in `_get_len_safe_embeddings`, which
before https://github.com/hwchase17/langchain/issues/3722 was actually
the **default implementation** regardless of the length of the context
(via https://github.com/hwchase17/langchain/pull/2330).

It appears the weights used are constant and the length of the embedding
vector (1536) and NOT the number of tokens in the batch, as in the
reference implementation at
https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb

<hr>

Here's some debug info:

<img width="1094" alt="image"
src="https://user-images.githubusercontent.com/1419010/235286595-a8b55298-7830-45df-b9f7-d2a2ad0356e0.png">

<hr>

We can also validate this against the reference implementation:

<details>

<summary>Reference implementation (click to unroll)</summary>

This implementation is copy pasted from
https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb

```py
import openai
from itertools import islice
import numpy as np
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type


EMBEDDING_MODEL = 'text-embedding-ada-002'
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = 'cl100k_base'

# let's make sure to not retry on an invalid request, because that is what we want to demonstrate
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(openai.InvalidRequestError))
def get_embedding(text_or_tokens, model=EMBEDDING_MODEL):
    return openai.Embedding.create(input=text_or_tokens, model=model)["data"][0]["embedding"]

def batched(iterable, n):
    """Batch data into tuples of length n. The last batch may be shorter."""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch
        
def chunked_tokens(text, encoding_name, chunk_length):
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator


def reference_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH, encoding_name=EMBEDDING_ENCODING, average=True):
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):
        chunk_embeddings.append(get_embedding(chunk, model=model))
        chunk_lens.append(len(chunk))

    if average:
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)  # normalizes length to 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings
```

</details>

```py
long_text = 'foo bar' * 5000

reference_safe_get_embedding(long_text, average=True)[:10]

# Here's the first 10 floats from the reference embeddings:
[0.004407593824276758,
 0.0017611146161865465,
 -0.019824815970984996,
 -0.02177626039794025,
 -0.012060967454897886,
 0.0017955296329155309,
 -0.015609168983609643,
 -0.012059823076681351,
 -0.016990468527792825,
 -0.004970484452089445]


# and now langchain implementation
from langchain.embeddings.openai import OpenAIEmbeddings
OpenAIEmbeddings().embed_query(long_text)[:10]

[0.003791506184693747,
 0.0025310066579390025,
 -0.019282322699514628,
 -0.021492679249899803,
 -0.012598522213242891,
 0.0022181168611315662,
 -0.015858940621301307,
 -0.011754004130791204,
 -0.016402944319627515,
 -0.004125287485127554]

# clearly they are different ^
```
pull/3920/head
Rafal Wojdyla 1 year ago committed by GitHub
parent 22a1896c30
commit 039b672f46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -150,7 +150,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
def _get_len_safe_embeddings(
self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None
) -> List[List[float]]:
embeddings: List[List[float]] = [[] for i in range(len(texts))]
embeddings: List[List[float]] = [[] for _ in range(len(texts))]
try:
import tiktoken
@ -180,10 +180,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
batched_embeddings += [r["embedding"] for r in response["data"]]
results: List[List[List[float]]] = [[] for _ in range(len(texts))]
lens: List[List[int]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
for i in range(len(indices)):
results[indices[i]].append(batched_embeddings[i])
lens[indices[i]].append(len(batched_embeddings[i]))
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
for i in range(len(texts)):
_result = results[i]
@ -192,7 +192,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"data"
][0]["embedding"]
else:
average = np.average(_result, axis=0, weights=lens[i])
average = np.average(
_result, axis=0, weights=num_tokens_in_batch[i]
)
embeddings[i] = (average / np.linalg.norm(average)).tolist()
return embeddings

Loading…
Cancel
Save