Fixup OpenAI Embeddings - fix the weighted mean (#3778)
Re: https://github.com/hwchase17/langchain/issues/3777

Copy-pasting from the issue:

While working on https://github.com/hwchase17/langchain/issues/3722 I noticed what appears to be a bug in the current implementation of the OpenAI length-safe embeddings in `_get_len_safe_embeddings`, which before https://github.com/hwchase17/langchain/issues/3722 was actually the **default implementation** regardless of the length of the context (via https://github.com/hwchase17/langchain/pull/2330).

The weights used are constant and equal to the length of the embedding vector (1536), NOT the number of tokens in each batch, as in the reference implementation at https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb

---

Here's some debug info:

![Debug screenshot showing the weights are a constant 1536 (the embedding length) rather than the per-chunk token counts](https://user-images.githubusercontent.com/1419010/235286595-a8b55298-7830-45df-b9f7-d2a2ad0356e0.png)

---

We can also validate this against the reference implementation:

<details>
<summary>Reference implementation (click to unroll)</summary>

This implementation is copy-pasted from https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb (with the missing `import tiktoken` added):

```py
import openai
import tiktoken
import numpy as np
from itertools import islice
from tenacity import retry, wait_random_exponential, stop_after_attempt, retry_if_not_exception_type

EMBEDDING_MODEL = 'text-embedding-ada-002'
EMBEDDING_CTX_LENGTH = 8191
EMBEDDING_ENCODING = 'cl100k_base'


# let's make sure to not retry on an invalid request, because that is what we want to demonstrate
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6),
       retry=retry_if_not_exception_type(openai.InvalidRequestError))
def get_embedding(text_or_tokens, model=EMBEDDING_MODEL):
    return openai.Embedding.create(input=text_or_tokens, model=model)["data"][0]["embedding"]


def batched(iterable, n):
    """Batch data into tuples of length n. The last batch may be shorter."""
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while (batch := tuple(islice(it, n))):
        yield batch


def chunked_tokens(text, encoding_name, chunk_length):
    encoding = tiktoken.get_encoding(encoding_name)
    tokens = encoding.encode(text)
    chunks_iterator = batched(tokens, chunk_length)
    yield from chunks_iterator


def reference_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH,
                                 encoding_name=EMBEDDING_ENCODING, average=True):
    chunk_embeddings = []
    chunk_lens = []
    for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):
        chunk_embeddings.append(get_embedding(chunk, model=model))
        chunk_lens.append(len(chunk))
    if average:
        chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
        chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings)  # normalizes length to 1
        chunk_embeddings = chunk_embeddings.tolist()
    return chunk_embeddings
```

</details>

```py
long_text = 'foo bar' * 5000

reference_safe_get_embedding(long_text, average=True)[:10]
# Here are the first 10 floats from the reference embedding:
[0.004407593824276758, 0.0017611146161865465, -0.019824815970984996, -0.02177626039794025,
 -0.012060967454897886, 0.0017955296329155309, -0.015609168983609643, -0.012059823076681351,
 -0.016990468527792825, -0.004970484452089445]

# and now the langchain implementation
from langchain.embeddings.openai import OpenAIEmbeddings
OpenAIEmbeddings().embed_query(long_text)[:10]
[0.003791506184693747, 0.0025310066579390025, -0.019282322699514628, -0.021492679249899803,
 -0.012598522213242891, 0.0022181168611315662, -0.015858940621301307, -0.011754004130791204,
 -0.016402944319627515, -0.004125287485127554]

# clearly they are different ^
```
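For completeness, here is a minimal sketch of what the corrected averaging step looks like. The helper name `combine_chunk_embeddings` and its parameter names are illustrative only, not the actual identifiers inside `_get_len_safe_embeddings`:

```py
import numpy as np

def combine_chunk_embeddings(chunk_embeddings, chunk_token_counts):
    """Token-weighted mean of per-chunk embeddings (illustrative helper)."""
    # Buggy behaviour: the weights were effectively a constant 1536
    # (the embedding length) for every chunk. Constant weights make
    # np.average degenerate to a plain mean, so a short final chunk
    # counts exactly as much as a full 8191-token chunk.
    # Fix: weight each chunk by the number of tokens it actually covers.
    average = np.average(chunk_embeddings, axis=0, weights=chunk_token_counts)
    # Re-normalize to unit length, as the cookbook reference does.
    return (average / np.linalg.norm(average)).tolist()
```

For example, a 9000-token input split at `EMBEDDING_CTX_LENGTH = 8191` yields chunks of 8191 and 809 tokens, so the weights should be `[8191, 809]`, not `[1536, 1536]`.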