{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Filtered Search with Zilliz and OpenAI\n",
"### Finding your next movie\n",
"\n",
"In this notebook we generate embeddings of movie descriptions with OpenAI and use those embeddings in Zilliz to find relevant movies. To narrow the search results, we combine the vector search with filters on the metadata. The dataset in this example is sourced from Hugging Face Datasets and contains a little over 8 thousand movie entries.\n",
"\n",
"Let's begin by installing the libraries required for this notebook:\n",
"- `openai` is used for communicating with the OpenAI embedding service\n",
"- `pymilvus` is used for communicating with the Zilliz server\n",
"- `datasets` is used for downloading the dataset\n",
"- `tqdm` is used for the progress bars\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! pip install openai pymilvus datasets tqdm"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To get Zilliz up and running, follow the [quick start guide](https://zilliz.com/doc/quick_start). With your account and database set up, set the following values:\n",
"- URI: The URI your database is running on\n",
"- TOKEN: Your database credentials, either `user:password` or an API key\n",
"- COLLECTION_NAME: What to name the collection within Zilliz\n",
"- DIMENSION: The dimension of the embeddings\n",
"- OPENAI_ENGINE: Which embedding model to use\n",
"- openai.api_key: Your OpenAI account key\n",
"- INDEX_PARAM: The index settings to use for the collection\n",
"- QUERY_PARAM: The search parameters to use\n",
"- BATCH_SIZE: How many texts to embed and insert at once"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"URI = 'your_uri'\n",
"TOKEN = 'your_token' # TOKEN == user:password or api_key\n",
"COLLECTION_NAME = 'book_search'\n",
"DIMENSION = 1536\n",
"OPENAI_ENGINE = 'text-embedding-3-small'\n",
"openai.api_key = 'sk-your_key'\n",
"\n",
"INDEX_PARAM = {\n",
"    'metric_type': 'L2',\n",
"    'index_type': \"AUTOINDEX\",\n",
"    'params': {}\n",
"}\n",
"\n",
"QUERY_PARAM = {\n",
"    \"metric_type\": \"L2\",\n",
"    \"params\": {},\n",
"}\n",
"\n",
"BATCH_SIZE = 1000"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pymilvus import connections, utility, FieldSchema, Collection, CollectionSchema, DataType\n",
"\n",
"# Connect to Zilliz Database\n",
"connections.connect(uri=URI, token=TOKEN)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Remove collection if it already exists\n",
"if utility.has_collection(COLLECTION_NAME):\n",
"    utility.drop_collection(COLLECTION_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Create a collection with the id, title, type, release year, rating, description, and embedding fields.\n",
"fields = [\n",
"    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
"    FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=64000),\n",
"    FieldSchema(name='type', dtype=DataType.VARCHAR, max_length=64000),\n",
"    FieldSchema(name='release_year', dtype=DataType.INT64),\n",
"    FieldSchema(name='rating', dtype=DataType.VARCHAR, max_length=64000),\n",
"    FieldSchema(name='description', dtype=DataType.VARCHAR, max_length=64000),\n",
"    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)\n",
"]\n",
"schema = CollectionSchema(fields=fields)\n",
"collection = Collection(name=COLLECTION_NAME, schema=schema)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Create the index on the collection and load it.\n",
"collection.create_index(field_name=\"embedding\", index_params=INDEX_PARAM)\n",
"collection.load()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset\n",
"With Zilliz up and running, we can begin grabbing our data. `Hugging Face Datasets` is a hub that hosts many user-contributed datasets; for this example we use the `hugginglearners/netflix-shows` dataset. It contains a little over 8 thousand entries, each pairing a description with its metadata. We embed each description and store it in Zilliz along with its title, type, release_year, and rating."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/filiphaltmayer/miniconda3/envs/haystack/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Found cached dataset csv (/Users/filiphaltmayer/.cache/huggingface/datasets/hugginglearners___csv/hugginglearners--netflix-shows-03475319fc65a05a/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)\n"
]
}
],
"source": [
"import datasets\n",
"\n",
"# Download the dataset \n",
"dataset = datasets.load_dataset('hugginglearners/netflix-shows', split='train')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Insert the Data\n",
"Now that the data is on our machine, we can embed it and insert it into Zilliz. The embedding function takes in text and returns the embeddings as a list, one embedding per input text."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Simple function that converts the texts to embeddings.\n",
"# Uses the openai>=1.0 interface; the older openai.Embedding.create call was removed in 1.0.\n",
"def embed(texts):\n",
"    response = openai.embeddings.create(\n",
"        input=texts,\n",
"        model=OPENAI_ENGINE\n",
"    )\n",
"    return [x.embedding for x in response.data]\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This next step does the actual inserting. We iterate through all the entries and build up a batch, inserting it once it reaches the set batch size. After the loop finishes, we insert the last remaining batch if one exists."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 8807/8807 [00:54<00:00, 162.59it/s]\n"
]
}
],
"source": [
"from tqdm import tqdm\n",
"\n",
"data = [\n",
"    [], # title\n",
"    [], # type\n",
"    [], # release_year\n",
"    [], # rating\n",
"    [], # description\n",
"]\n",
"\n",
"# Embed and insert in batches\n",
"for i in tqdm(range(0, len(dataset))):\n",
"    data[0].append(dataset[i]['title'] or '')\n",
"    data[1].append(dataset[i]['type'] or '')\n",
"    data[2].append(dataset[i]['release_year'] or -1)\n",
"    data[3].append(dataset[i]['rating'] or '')\n",
"    data[4].append(dataset[i]['description'] or '')\n",
"    if len(data[0]) % BATCH_SIZE == 0:\n",
"        data.append(embed(data[4]))\n",
"        collection.insert(data)\n",
"        data = [[],[],[],[],[]]\n",
"\n",
"# Embed and insert the remainder\n",
"if len(data[0]) != 0:\n",
"    data.append(embed(data[4]))\n",
"    collection.insert(data)\n",
"    data = [[],[],[],[],[]]\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Query the Database\n",
"With our data inserted into Zilliz, we can now perform a query. The query function takes a tuple of the movie description you are searching for and the filter expression to apply. More information about filter expressions can be found [here](https://milvus.io/docs/boolean.md). The search first prints your description and filter expression. Then, for each result, it prints the score, title, type, release year, rating, and description. Because the collection uses the L2 metric, the score is a distance, so lower values indicate closer matches."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Description: movie about a fluffly animal Expression: release_year < 2019 and rating like \"PG%\"\n",
"Results:\n",
"\tRank: 1 Score: 0.30085673928260803 Title: The Lamb\n",
"\t\tType: Movie Release Year: 2017 Rating: PG\n",
"A big-dreaming donkey escapes his menial existence and befriends some free-spirited\n",
"animal pals in this imaginative retelling of the Nativity Story.\n",
"\n",
"\tRank: 2 Score: 0.3352621793746948 Title: Puss in Boots\n",
"\t\tType: Movie Release Year: 2011 Rating: PG\n",
"The fabled feline heads to the Land of Giants with friends Humpty Dumpty and Kitty\n",
"Softpaws on a quest to nab its greatest treasure: the Golden Goose.\n",
"\n",
"\tRank: 3 Score: 0.3415083587169647 Title: Show Dogs\n",
"\t\tType: Movie Release Year: 2018 Rating: PG\n",
"A rough and tough police dog must go undercover with an FBI agent as a prim and proper\n",
"pet at a dog show to save a baby panda from an illegal sale.\n",
"\n",
"\tRank: 4 Score: 0.3428957462310791 Title: Open Season 2\n",
"\t\tType: Movie Release Year: 2008 Rating: PG\n",
"Elliot the buck and his forest-dwelling cohorts must rescue their dachshund pal from\n",
"some spoiled pets bent on returning him to domesticity.\n",
"\n",
"\tRank: 5 Score: 0.34376364946365356 Title: Stuart Little 2\n",
"\t\tType: Movie Release Year: 2002 Rating: PG\n",
"Zany misadventures are in store as lovable city mouse Stuart and his human brother,\n",
"George, raise the roof in this sequel to the 1999 blockbuster.\n",
"\n"
]
}
],
"source": [
"import textwrap\n",
"\n",
"def query(query, top_k = 5):\n",
"    text, expr = query\n",
"    res = collection.search(\n",
"        embed(text),\n",
"        anns_field='embedding',\n",
"        expr=expr,\n",
"        param=QUERY_PARAM,\n",
"        limit=top_k,\n",
"        output_fields=['title', 'type', 'release_year', 'rating', 'description']\n",
"    )\n",
"    # res holds one list of hits per query vector (a single one here)\n",
"    for hits in res:\n",
"        print('Description:', text, 'Expression:', expr)\n",
"        print('Results:')\n",
"        for rank, hit in enumerate(hits, start=1):\n",
"            print('\\t' + 'Rank:', rank, 'Score:', hit.score, 'Title:', hit.entity.get('title'))\n",
"            print('\\t\\t' + 'Type:', hit.entity.get('type'), 'Release Year:', hit.entity.get('release_year'), 'Rating:', hit.entity.get('rating'))\n",
"            print(textwrap.fill(hit.entity.get('description'), 88))\n",
"            print()\n",
"\n",
"my_query = ('movie about a fluffly animal', 'release_year < 2019 and rating like \\\"PG%\\\"')\n",
"\n",
"query(my_query)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "haystack",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}