diff --git a/examples/vector_databases/redis/redis-hybrid-query-examples.ipynb b/examples/vector_databases/redis/redis-hybrid-query-examples.ipynb index 352e1c40..0ebf38a8 100644 --- a/examples/vector_databases/redis/redis-hybrid-query-examples.ipynb +++ b/examples/vector_databases/redis/redis-hybrid-query-examples.ipynb @@ -304,7 +304,7 @@ "from typing import List\n", "\n", "from openai.embeddings_utils import (\n", - " get_embedding,\n", + " get_embeddings,\n", " distances_from_embeddings,\n", " tsne_components_from_embeddings,\n", " chart_from_components,\n", @@ -527,36 +527,62 @@ { "cell_type": "code", "execution_count": 10, + "id": "852cff45", + "metadata": {}, + "outputs": [], + "source": [ + "# Use OpenAI get_embeddings batch requests to speed up embedding creation\n", + "def embeddings_batch_request(documents: pd.DataFrame):\n", + " records = documents.to_dict(\"records\")\n", + " print(\"Records to process: \", len(records))\n", + " product_vectors = []\n", + " docs = []\n", + " batchsize = 1000\n", + " \n", + " for idx,doc in enumerate(records,start=1):\n", + " # collect the product text for the current embedding batch\n", + " docs.append(doc[\"product_text\"])\n", + " if idx % batchsize == 0:\n", + " product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n", + " docs.clear()\n", + " print(\"Vectors processed \", len(product_vectors), end='\\r')\n", + " product_vectors += get_embeddings(docs, EMBEDDING_MODEL)\n", + " print(\"Vectors processed \", len(product_vectors), end='\\r')\n", + " return product_vectors" + ] + }, + { + "cell_type": "code", + "execution_count": 11, "id": "0d791186", "metadata": {}, "outputs": [], "source": [ "def index_documents(client: redis.Redis, prefix: str, documents: pd.DataFrame):\n", + " product_vectors = embeddings_batch_request(documents)\n", " records = documents.to_dict(\"records\")\n", + " batchsize = 500\n", " \n", " # Use Redis pipelines to batch calls and save on round trip network communication\n", " pipe = client.pipeline()\n", - " batch = 0\n", - " 
for doc in records:\n", + " for idx,doc in enumerate(records,start=1):\n", " key = f\"{prefix}:{str(doc['product_id'])}\"\n", "\n", " # create byte vectors\n", - " text_embedding = np.array(get_embedding(doc[\"product_text\"], EMBEDDING_MODEL), dtype=np.float32).tobytes()\n", - "\n", + " text_embedding = np.array((product_vectors[idx-1]), dtype=np.float32).tobytes()\n", + " \n", " # replace list of floats with byte vectors\n", " doc[\"product_vector\"] = text_embedding\n", "\n", " pipe.hset(key, mapping = doc)\n", - " batch += 1\n", - " if batch == 500:\n", + " if idx % batchsize == 0:\n", " pipe.execute()\n", - " batch = 0\n", " pipe.execute()" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "5bfaeafa", "metadata": {}, "outputs": [ @@ -564,12 +590,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "Loaded 1978 documents in Redis search index with name: product_embeddings\n" + "Records to process: 1978\n", + "Loaded 1978 documents in Redis search index with name: product_embeddings\n", + "CPU times: user 619 ms, sys: 78.9 ms, total: 698 ms\n", + "Wall time: 3.34 s\n" ] } ], "source": [ - "# the styles_2k dataset should take 5-10min to load. 
the larger styles_40k dataset will take hours\n", + "%%time\n", "index_documents(redis_client, PREFIX, df)\n", "print(f\"Loaded {redis_client.info()['db0']['keys']} documents in Redis search index with name: {INDEX_NAME}\")" ] @@ -586,7 +615,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "b044aa93", "metadata": {}, "outputs": [], @@ -629,7 +658,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "7e2025f6", "metadata": {}, "outputs": [ @@ -667,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "2c81fbb7", "metadata": {}, "outputs": [ @@ -700,7 +729,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "8a56633b", "metadata": {}, "outputs": [ @@ -714,8 +743,8 @@ "3. Basics Men Blue Slim Fit Checked Shirt (Score: 0.627)\n", "4. Basics Men Red Slim Fit Checked Shirt (Score: 0.623)\n", "5. Basics Men Navy Slim Fit Checked Shirt (Score: 0.613)\n", - "6. Lee Rinse Navy Blue Slim Fit Jeans (Score: 0.559)\n", - "7. Tokyo Talkies Women Navy Slim Fit Jeans (Score: 0.553)\n" + "6. Lee Rinse Navy Blue Slim Fit Jeans (Score: 0.558)\n", + "7. Tokyo Talkies Women Navy Slim Fit Jeans (Score: 0.552)\n" ] } ], @@ -731,7 +760,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "id": "6c25ee8d", "metadata": {}, "outputs": [ @@ -747,8 +776,8 @@ "5. CASIO Youth Series Digital Men Black Small Dial Digital Watch W-210-1CVDF I065 (Score: 0.542)\n", "6. Titan Women Silver Watch (Score: 0.542)\n", "7. Police Men Black Dial Watch PL12778MSU-61 (Score: 0.541)\n", - "8. ADIDAS Original Men Black Dial Chronograph Watch ADH2641 (Score: 0.539)\n", - "9. Titan Raga Women Gold Watch (Score: 0.539)\n" + "8. Titan Raga Women Gold Watch (Score: 0.539)\n", + "9. 
ADIDAS Original Men Black Dial Chronograph Watch ADH2641 (Score: 0.539)\n" ] } ], @@ -764,7 +793,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "id": "2c0d11d8", "metadata": {}, "outputs": [ @@ -774,9 +803,9 @@ "text": [ "0. Enroute Teens Orange Sandals (Score: 0.701)\n", "1. Fila Men Camper Brown Sandals (Score: 0.692)\n", - "2. Coolers Men Black Sandals (Score: 0.69)\n", + "2. Clarks Men Black Leather Closed Sandals (Score: 0.691)\n", "3. Coolers Men Black Sandals (Score: 0.69)\n", - "4. Clarks Men Black Leather Closed Sandals (Score: 0.69)\n", + "4. Coolers Men Black Sandals (Score: 0.69)\n", "5. Enroute Teens Brown Sandals (Score: 0.69)\n", "6. Crocs Dora Boots Pink Sandals (Score: 0.69)\n", "7. Enroute Men Leather Black Sandals (Score: 0.685)\n", @@ -797,7 +826,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "id": "7caad384", "metadata": {}, "outputs": [ @@ -830,7 +859,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "id": "f1232d3c", "metadata": {}, "outputs": [ @@ -839,17 +868,21 @@ "output_type": "stream", "text": [ "0. Wrangler Men Leather Brown Belt (Score: 0.67)\n", - "1. Wrangler Women Black Belt (Score: 0.639)\n" + "1. Wrangler Women Black Belt (Score: 0.639)\n", + "2. Wrangler Men Green Striped Shirt (Score: 0.575)\n", + "3. Wrangler Men Purple Striped Shirt (Score: 0.549)\n", + "4. Wrangler Men Griffith White Shirt (Score: 0.543)\n", + "5. 
Wrangler Women Stella Green Shirt (Score: 0.542)\n" ] } ], "source": [ - "# hybrid query for a brown belt filtering results by a year (NUMERIC) with a specific article type (TAG) and with a brand name (TEXT)\n", + "# hybrid query for a brown belt filtering results by a year (NUMERIC) with specific article types (TAG) and with a brand name (TEXT)\n", "results = search_redis(redis_client,\n", " \"brown belt\",\n", " vector_field=\"product_vector\",\n", " k=10,\n", - " hybrid_fields='(@year:[2012 2012] @articleType:{Belts} @productDisplayName:\"Wrangler\")'\n", + " hybrid_fields='(@year:[2012 2012] @articleType:{Shirts | Belts} @productDisplayName:\"Wrangler\")'\n", " )" ] }