implemented batcher for openai embeddings

pull/1077/head
Michael Yuan 1 year ago
parent 7be3b146cf
commit 6ce07ae71d

@ -304,7 +304,7 @@
"from typing import List\n", "from typing import List\n",
"\n", "\n",
"from openai.embeddings_utils import (\n", "from openai.embeddings_utils import (\n",
" get_embedding,\n", " get_embeddings,\n",
" distances_from_embeddings,\n", " distances_from_embeddings,\n",
" tsne_components_from_embeddings,\n", " tsne_components_from_embeddings,\n",
" chart_from_components,\n", " chart_from_components,\n",
@ -527,36 +527,62 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 10,
"id": "852cff45",
"metadata": {},
"outputs": [],
"source": [
"# Use OpenAI get_embeddings batch requests to speed up embedding creation\n",
"def embeddings_batch_request(documents: pd.DataFrame):\n",
" records = documents.to_dict(\"records\")\n",
" print(\"Records to process: \", len(records))\n",
" product_vectors = []\n",
" docs = []\n",
" batchsize = 1000\n",
" \n",
" for idx,doc in enumerate(records,start=1):\n",
" # collect the product text for the next batch request\n",
" docs.append(doc[\"product_text\"])\n",
" if idx % batchsize == 0:\n",
" product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n",
" docs.clear()\n",
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
" product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n",
" print(\"Vectors processed \", len(product_vectors), end='\\r')\n",
" return product_vectors"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0d791186", "id": "0d791186",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def index_documents(client: redis.Redis, prefix: str, documents: pd.DataFrame):\n", "def index_documents(client: redis.Redis, prefix: str, documents: pd.DataFrame):\n",
" product_vectors = embeddings_batch_request(documents)\n",
" records = documents.to_dict(\"records\")\n", " records = documents.to_dict(\"records\")\n",
" batchsize = 500\n",
" \n", " \n",
" # Use Redis pipelines to batch calls and save on round trip network communication\n", " # Use Redis pipelines to batch calls and save on round trip network communication\n",
" pipe = client.pipeline()\n", " pipe = client.pipeline()\n",
" batch = 0\n", " for idx,doc in enumerate(records,start=1):\n",
" for doc in records:\n",
" key = f\"{prefix}:{str(doc['product_id'])}\"\n", " key = f\"{prefix}:{str(doc['product_id'])}\"\n",
"\n", "\n",
" # create byte vectors\n", " # create byte vectors\n",
" text_embedding = np.array(get_embedding(doc[\"product_text\"], EMBEDDING_MODEL), dtype=np.float32).tobytes()\n", " text_embedding = np.array((product_vectors[idx-1]), dtype=np.float32).tobytes()\n",
"\n", " \n",
" # replace list of floats with byte vectors\n", " # replace list of floats with byte vectors\n",
" doc[\"product_vector\"] = text_embedding\n", " doc[\"product_vector\"] = text_embedding\n",
"\n", "\n",
" pipe.hset(key, mapping = doc)\n", " pipe.hset(key, mapping = doc)\n",
" batch += 1\n", " if idx % batchsize == 0:\n",
" if batch == 500:\n",
" pipe.execute()\n", " pipe.execute()\n",
" batch = 0\n",
" pipe.execute()" " pipe.execute()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 12,
"id": "5bfaeafa", "id": "5bfaeafa",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -564,12 +590,15 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Loaded 1978 documents in Redis search index with name: product_embeddings\n" "Records to process: 1978\n",
"Loaded 1978 documents in Redis search index with name: product_embeddings\n",
"CPU times: user 619 ms, sys: 78.9 ms, total: 698 ms\n",
"Wall time: 3.34 s\n"
] ]
} }
], ],
"source": [ "source": [
"# the styles_2k dataset should take 5-10min to load. the larger styles_40k dataset will take hours\n", "%%time\n",
"index_documents(redis_client, PREFIX, df)\n", "index_documents(redis_client, PREFIX, df)\n",
"print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")" "print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")"
] ]
@ -586,7 +615,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 13,
"id": "b044aa93", "id": "b044aa93",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -629,7 +658,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 14,
"id": "7e2025f6", "id": "7e2025f6",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -667,7 +696,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 15,
"id": "2c81fbb7", "id": "2c81fbb7",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -700,7 +729,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 16,
"id": "8a56633b", "id": "8a56633b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -714,8 +743,8 @@
"3. Basics Men Blue Slim Fit Checked Shirt (Score: 0.627)\n", "3. Basics Men Blue Slim Fit Checked Shirt (Score: 0.627)\n",
"4. Basics Men Red Slim Fit Checked Shirt (Score: 0.623)\n", "4. Basics Men Red Slim Fit Checked Shirt (Score: 0.623)\n",
"5. Basics Men Navy Slim Fit Checked Shirt (Score: 0.613)\n", "5. Basics Men Navy Slim Fit Checked Shirt (Score: 0.613)\n",
"6. Lee Rinse Navy Blue Slim Fit Jeans (Score: 0.559)\n", "6. Lee Rinse Navy Blue Slim Fit Jeans (Score: 0.558)\n",
"7. Tokyo Talkies Women Navy Slim Fit Jeans (Score: 0.553)\n" "7. Tokyo Talkies Women Navy Slim Fit Jeans (Score: 0.552)\n"
] ]
} }
], ],
@ -731,7 +760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 17,
"id": "6c25ee8d", "id": "6c25ee8d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -747,8 +776,8 @@
"5. CASIO Youth Series Digital Men Black Small Dial Digital Watch W-210-1CVDF I065 (Score: 0.542)\n", "5. CASIO Youth Series Digital Men Black Small Dial Digital Watch W-210-1CVDF I065 (Score: 0.542)\n",
"6. Titan Women Silver Watch (Score: 0.542)\n", "6. Titan Women Silver Watch (Score: 0.542)\n",
"7. Police Men Black Dial Watch PL12778MSU-61 (Score: 0.541)\n", "7. Police Men Black Dial Watch PL12778MSU-61 (Score: 0.541)\n",
"8. ADIDAS Original Men Black Dial Chronograph Watch ADH2641 (Score: 0.539)\n", "8. Titan Raga Women Gold Watch (Score: 0.539)\n",
"9. Titan Raga Women Gold Watch (Score: 0.539)\n" "9. ADIDAS Original Men Black Dial Chronograph Watch ADH2641 (Score: 0.539)\n"
] ]
} }
], ],
@ -764,7 +793,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 18,
"id": "2c0d11d8", "id": "2c0d11d8",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -774,9 +803,9 @@
"text": [ "text": [
"0. Enroute Teens Orange Sandals (Score: 0.701)\n", "0. Enroute Teens Orange Sandals (Score: 0.701)\n",
"1. Fila Men Camper Brown Sandals (Score: 0.692)\n", "1. Fila Men Camper Brown Sandals (Score: 0.692)\n",
"2. Coolers Men Black Sandals (Score: 0.69)\n", "2. Clarks Men Black Leather Closed Sandals (Score: 0.691)\n",
"3. Coolers Men Black Sandals (Score: 0.69)\n", "3. Coolers Men Black Sandals (Score: 0.69)\n",
"4. Clarks Men Black Leather Closed Sandals (Score: 0.69)\n", "4. Coolers Men Black Sandals (Score: 0.69)\n",
"5. Enroute Teens Brown Sandals (Score: 0.69)\n", "5. Enroute Teens Brown Sandals (Score: 0.69)\n",
"6. Crocs Dora Boots Pink Sandals (Score: 0.69)\n", "6. Crocs Dora Boots Pink Sandals (Score: 0.69)\n",
"7. Enroute Men Leather Black Sandals (Score: 0.685)\n", "7. Enroute Men Leather Black Sandals (Score: 0.685)\n",
@ -797,7 +826,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 19,
"id": "7caad384", "id": "7caad384",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -830,7 +859,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 20,
"id": "f1232d3c", "id": "f1232d3c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -839,17 +868,21 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"0. Wrangler Men Leather Brown Belt (Score: 0.67)\n", "0. Wrangler Men Leather Brown Belt (Score: 0.67)\n",
"1. Wrangler Women Black Belt (Score: 0.639)\n" "1. Wrangler Women Black Belt (Score: 0.639)\n",
"2. Wrangler Men Green Striped Shirt (Score: 0.575)\n",
"3. Wrangler Men Purple Striped Shirt (Score: 0.549)\n",
"4. Wrangler Men Griffith White Shirt (Score: 0.543)\n",
"5. Wrangler Women Stella Green Shirt (Score: 0.542)\n"
] ]
} }
], ],
"source": [ "source": [
"# hybrid query for a brown belt filtering results by a year (NUMERIC) with a specific article type (TAG) and with a brand name (TEXT)\n", "# hybrid query for a brown belt filtering results by a year (NUMERIC) with specific article types (TAG) and with a brand name (TEXT)\n",
"results = search_redis(redis_client,\n", "results = search_redis(redis_client,\n",
" \"brown belt\",\n", " \"brown belt\",\n",
" vector_field=\"product_vector\",\n", " vector_field=\"product_vector\",\n",
" k=10,\n", " k=10,\n",
" hybrid_fields='(@year:[2012 2012] @articleType:{Belts} @productDisplayName:\"Wrangler\")'\n", " hybrid_fields='(@year:[2012 2012] @articleType:{Shirts | Belts} @productDisplayName:\"Wrangler\")'\n",
" )" " )"
] ]
} }

Loading…
Cancel
Save