implemented batcher for openai embeddings

pull/417/head
Michael Yuan 1 year ago
parent 47046a1c86
commit 32cddbef25

@ -304,7 +304,7 @@
"from typing import List\n",
"\n",
"from openai.embeddings_utils import (\n",
" get_embedding,\n",
" get_embeddings,\n",
" distances_from_embeddings,\n",
" tsne_components_from_embeddings,\n",
" chart_from_components,\n",
@ -527,36 +527,62 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "852cff45",
"metadata": {},
"outputs": [],
"source": [
"# Use OpenAI get_embeddings batch requests to speed up embedding creation\n",
"def embeddings_batch_request(documents: pd.DataFrame, batchsize: int = 1000):\n",
"    \"\"\"Return one embedding per row of documents, in row order.\n",
"\n",
"    The product_text value of every record is sent to get_embeddings in\n",
"    slices of at most batchsize texts. No request is made for an empty\n",
"    slice, so an empty DataFrame or a row count that is an exact multiple\n",
"    of batchsize does not end with an empty-input call.\n",
"    \"\"\"\n",
"    records = documents.to_dict(\"records\")\n",
"    print(\"Records to process: \", len(records))\n",
"    texts = [doc[\"product_text\"] for doc in records]\n",
"    product_vectors = []\n",
"\n",
"    # range() yields nothing when texts is empty and never produces an empty slice\n",
"    for start in range(0, len(texts), batchsize):\n",
"        product_vectors += get_embeddings(texts[start:start + batchsize], EMBEDDING_MODEL)\n",
"        print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
"    return product_vectors"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0d791186",
"metadata": {},
"outputs": [],
"source": [
"def index_documents(client: redis.Redis, prefix: str, documents: pd.DataFrame):\n",
" product_vectors = embeddings_batch_request(documents)\n",
" records = documents.to_dict(\"records\")\n",
" batchsize = 500\n",
" \n",
" # Use Redis pipelines to batch calls and save on round trip network communication\n",
" pipe = client.pipeline()\n",
" batch = 0\n",
" for doc in records:\n",
" for idx,doc in enumerate(records,start=1):\n",
" key = f\"{prefix}:{str(doc['product_id'])}\"\n",
"\n",
" # create byte vectors\n",
" text_embedding = np.array(get_embedding(doc[\"product_text\"], EMBEDDING_MODEL), dtype=np.float32).tobytes()\n",
"\n",
" text_embedding = np.array((product_vectors[idx-1]), dtype=np.float32).tobytes()\n",
" \n",
" # replace list of floats with byte vectors\n",
" doc[\"product_vector\"] = text_embedding\n",
"\n",
" pipe.hset(key, mapping = doc)\n",
" batch += 1\n",
" if batch == 500:\n",
" if idx % batchsize == 0:\n",
" pipe.execute()\n",
" batch = 0\n",
" pipe.execute()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "5bfaeafa",
"metadata": {},
"outputs": [
@ -564,12 +590,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded 1978 documents in Redis search index with name: product_embeddings\n"
"Records to process: 1978\n",
"Loaded 1978 documents in Redis search index with name: product_embeddings\n",
"CPU times: user 619 ms, sys: 78.9 ms, total: 698 ms\n",
"Wall time: 3.34 s\n"
]
}
],
"source": [
"# with batched embedding requests the styles_2k dataset loads in a few seconds. the larger styles_40k dataset will take longer\n",
"%%time\n",
"index_documents(redis_client, PREFIX, df)\n",
"print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")"
]
@ -586,7 +615,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "b044aa93",
"metadata": {},
"outputs": [],
@ -629,7 +658,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "7e2025f6",
"metadata": {},
"outputs": [
@ -667,7 +696,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "2c81fbb7",
"metadata": {},
"outputs": [
@ -700,7 +729,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "8a56633b",
"metadata": {},
"outputs": [
@ -714,8 +743,8 @@
"3. Basics Men Blue Slim Fit Checked Shirt (Score: 0.627)\n",
"4. Basics Men Red Slim Fit Checked Shirt (Score: 0.623)\n",
"5. Basics Men Navy Slim Fit Checked Shirt (Score: 0.613)\n",
"6. Lee Rinse Navy Blue Slim Fit Jeans (Score: 0.559)\n",
"7. Tokyo Talkies Women Navy Slim Fit Jeans (Score: 0.553)\n"
"6. Lee Rinse Navy Blue Slim Fit Jeans (Score: 0.558)\n",
"7. Tokyo Talkies Women Navy Slim Fit Jeans (Score: 0.552)\n"
]
}
],
@ -731,7 +760,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "6c25ee8d",
"metadata": {},
"outputs": [
@ -747,8 +776,8 @@
"5. CASIO Youth Series Digital Men Black Small Dial Digital Watch W-210-1CVDF I065 (Score: 0.542)\n",
"6. Titan Women Silver Watch (Score: 0.542)\n",
"7. Police Men Black Dial Watch PL12778MSU-61 (Score: 0.541)\n",
"8. ADIDAS Original Men Black Dial Chronograph Watch ADH2641 (Score: 0.539)\n",
"9. Titan Raga Women Gold Watch (Score: 0.539)\n"
"8. Titan Raga Women Gold Watch (Score: 0.539)\n",
"9. ADIDAS Original Men Black Dial Chronograph Watch ADH2641 (Score: 0.539)\n"
]
}
],
@ -764,7 +793,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "2c0d11d8",
"metadata": {},
"outputs": [
@ -774,9 +803,9 @@
"text": [
"0. Enroute Teens Orange Sandals (Score: 0.701)\n",
"1. Fila Men Camper Brown Sandals (Score: 0.692)\n",
"2. Coolers Men Black Sandals (Score: 0.69)\n",
"2. Clarks Men Black Leather Closed Sandals (Score: 0.691)\n",
"3. Coolers Men Black Sandals (Score: 0.69)\n",
"4. Clarks Men Black Leather Closed Sandals (Score: 0.69)\n",
"4. Coolers Men Black Sandals (Score: 0.69)\n",
"5. Enroute Teens Brown Sandals (Score: 0.69)\n",
"6. Crocs Dora Boots Pink Sandals (Score: 0.69)\n",
"7. Enroute Men Leather Black Sandals (Score: 0.685)\n",
@ -797,7 +826,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"id": "7caad384",
"metadata": {},
"outputs": [
@ -830,7 +859,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"id": "f1232d3c",
"metadata": {},
"outputs": [
@ -839,17 +868,21 @@
"output_type": "stream",
"text": [
"0. Wrangler Men Leather Brown Belt (Score: 0.67)\n",
"1. Wrangler Women Black Belt (Score: 0.639)\n"
"1. Wrangler Women Black Belt (Score: 0.639)\n",
"2. Wrangler Men Green Striped Shirt (Score: 0.575)\n",
"3. Wrangler Men Purple Striped Shirt (Score: 0.549)\n",
"4. Wrangler Men Griffith White Shirt (Score: 0.543)\n",
"5. Wrangler Women Stella Green Shirt (Score: 0.542)\n"
]
}
],
"source": [
"# hybrid query for a brown belt filtering results by a year (NUMERIC) with a specific article type (TAG) and with a brand name (TEXT)\n",
"# hybrid query for a brown belt filtering results by a year (NUMERIC) with specific article types (TAG) and with a brand name (TEXT)\n",
"results = search_redis(redis_client,\n",
" \"brown belt\",\n",
" vector_field=\"product_vector\",\n",
" k=10,\n",
" hybrid_fields='(@year:[2012 2012] @articleType:{Belts} @productDisplayName:\"Wrangler\")'\n",
" hybrid_fields='(@year:[2012 2012] @articleType:{Shirts | Belts} @productDisplayName:\"Wrangler\")'\n",
" )"
]
}

Loading…
Cancel
Save