Fix weighted average

pull/114/head
Filipe de Avila Belbute Peres 1 year ago
parent 9f4bb5260a
commit d339859248

@@ -200,11 +200,13 @@
"\n",
"def len_safe_get_embedding(text, model=EMBEDDING_MODEL, max_tokens=EMBEDDING_CTX_LENGTH, encoding_name=EMBEDDING_ENCODING, average=True):\n",
" chunk_embeddings = []\n",
+" chunk_lens = []\n",
" for chunk in chunked_tokens(text, encoding_name=encoding_name, chunk_length=max_tokens):\n",
" chunk_embeddings.append(get_embedding(chunk, model=model))\n",
+" chunk_lens.append(len(chunk))\n",
"\n",
" if average:\n",
-" chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=[len(c) for c in chunk_embeddings])\n",
+" chunk_embeddings = np.average(chunk_embeddings, axis=0, weights=chunk_lens)\n",
" chunk_embeddings = chunk_embeddings / np.linalg.norm(chunk_embeddings) # normalizes length to 1\n",
" chunk_embeddings = chunk_embeddings.tolist()\n",
" return chunk_embeddings"

Loading…
Cancel
Save