Extend opensearch to better support existing instances (#2500) (#2509)

Closes #2500.
This commit is contained in:
Sam Weaver 2023-04-06 15:45:56 -04:00 committed by GitHub
parent ad87584c35
commit 2ffb90b161
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 9 deletions

View File

@ -175,7 +175,7 @@
"docsearch = OpenSearchVectorSearch.from_texts(texts, embeddings, opensearch_url=\"http://localhost:9200\", is_appx_search=False)\n", "docsearch = OpenSearchVectorSearch.from_texts(texts, embeddings, opensearch_url=\"http://localhost:9200\", is_appx_search=False)\n",
"filter = {\"bool\": {\"filter\": {\"term\": {\"text\": \"smuggling\"}}}}\n", "filter = {\"bool\": {\"filter\": {\"term\": {\"text\": \"smuggling\"}}}}\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.similarity_search(\"What did the president say about Ketanji Brown Jackson\", search_type=\"painless_scripting\", space_type=\"cosineSimilarity\", pre_filter=filter)" "docs = docsearch.similarity_search(\"What did the president say about Ketanji Brown Jackson\", search_type=\"painless_scripting\", space_type=\"cosinesimil\", pre_filter=filter)"
] ]
}, },
{ {
@ -191,6 +191,30 @@
"source": [ "source": [
"print(docs[0].page_content)" "print(docs[0].page_content)"
] ]
},
{
"cell_type": "markdown",
"id": "73264864",
"metadata": {},
"source": [
"#### Using a preexisting OpenSearch instance\n",
"\n",
"It's also possible to use a preexisting OpenSearch instance with documents that already have vectors present."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82a23440",
"metadata": {},
"outputs": [],
"source": [
"# this is just an example, you would need to change these values to point to another opensearch instance\n",
"docsearch = OpenSearchVectorSearch(index_name=\"index-*\", embedding_function=embeddings, opensearch_url=\"http://localhost:9200\")\n",
"\n",
"# you can specify custom field names to match the fields you're using to store your embedding, document text value, and metadata\n",
"docs = docsearch.similarity_search(\"Who was asking about getting lunch today?\", search_type=\"script_scoring\", space_type=\"cosinesimil\", vector_field=\"message_embedding\", text_field=\"message\", metadata_field=\"message_metadata\")"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -128,12 +128,15 @@ def _default_text_mapping(
def _default_approximate_search_query( def _default_approximate_search_query(
query_vector: List[float], size: int = 4, k: int = 4 query_vector: List[float],
size: int = 4,
k: int = 4,
vector_field: str = "vector_field",
) -> Dict: ) -> Dict:
"""For Approximate k-NN Search, this is the default query.""" """For Approximate k-NN Search, this is the default query."""
return { return {
"size": size, "size": size,
"query": {"knn": {"vector_field": {"vector": query_vector, "k": k}}}, "query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
} }
@ -141,6 +144,7 @@ def _default_script_query(
query_vector: List[float], query_vector: List[float],
space_type: str = "l2", space_type: str = "l2",
pre_filter: Dict = MATCH_ALL_QUERY, pre_filter: Dict = MATCH_ALL_QUERY,
vector_field: str = "vector_field",
) -> Dict: ) -> Dict:
"""For Script Scoring Search, this is the default query.""" """For Script Scoring Search, this is the default query."""
return { return {
@ -151,7 +155,7 @@ def _default_script_query(
"source": "knn_score", "source": "knn_score",
"lang": "knn", "lang": "knn",
"params": { "params": {
"field": "vector_field", "field": vector_field,
"query_value": query_vector, "query_value": query_vector,
"space_type": space_type, "space_type": space_type,
}, },
@ -176,6 +180,7 @@ def _default_painless_scripting_query(
query_vector: List[float], query_vector: List[float],
space_type: str = "l2Squared", space_type: str = "l2Squared",
pre_filter: Dict = MATCH_ALL_QUERY, pre_filter: Dict = MATCH_ALL_QUERY,
vector_field: str = "vector_field",
) -> Dict: ) -> Dict:
"""For Painless Scripting Search, this is the default query.""" """For Painless Scripting Search, this is the default query."""
source = __get_painless_scripting_source(space_type, query_vector) source = __get_painless_scripting_source(space_type, query_vector)
@ -186,7 +191,7 @@ def _default_painless_scripting_query(
"script": { "script": {
"source": source, "source": source,
"params": { "params": {
"field": "vector_field", "field": vector_field,
"query_value": query_vector, "query_value": query_vector,
}, },
}, },
@ -269,6 +274,15 @@ class OpenSearchVectorSearch(VectorStore):
Returns: Returns:
List of Documents most similar to the query. List of Documents most similar to the query.
Optional Args:
vector_field: Document field embeddings are stored in. Defaults to
"vector_field".
text_field: Document field the text of the document is stored in. Defaults
to "text".
metadata_field: Document field that metadata is stored in. Defaults to
"metadata".
Can be set to a special value "*" to include the entire document.
Optional Args for Approximate Search: Optional Args for Approximate Search:
search_type: "approximate_search"; default: "approximate_search" search_type: "approximate_search"; default: "approximate_search"
size: number of results the query actually returns; default: 4 size: number of results the query actually returns; default: 4
@ -291,18 +305,27 @@ class OpenSearchVectorSearch(VectorStore):
""" """
embedding = self.embedding_function.embed_query(query) embedding = self.embedding_function.embed_query(query)
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search") search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
text_field = _get_kwargs_value(kwargs, "text_field", "text")
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
if search_type == "approximate_search": if search_type == "approximate_search":
size = _get_kwargs_value(kwargs, "size", 4) size = _get_kwargs_value(kwargs, "size", 4)
search_query = _default_approximate_search_query(embedding, size, k) vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
search_query = _default_approximate_search_query(
embedding, size, k, vector_field
)
elif search_type == SCRIPT_SCORING_SEARCH: elif search_type == SCRIPT_SCORING_SEARCH:
space_type = _get_kwargs_value(kwargs, "space_type", "l2") space_type = _get_kwargs_value(kwargs, "space_type", "l2")
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
search_query = _default_script_query(embedding, space_type, pre_filter) vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
search_query = _default_script_query(
embedding, space_type, pre_filter, vector_field
)
elif search_type == PAINLESS_SCRIPTING_SEARCH: elif search_type == PAINLESS_SCRIPTING_SEARCH:
space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared") space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared")
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
search_query = _default_painless_scripting_query( search_query = _default_painless_scripting_query(
embedding, space_type, pre_filter embedding, space_type, pre_filter, vector_field
) )
else: else:
raise ValueError("Invalid `search_type` provided as an argument") raise ValueError("Invalid `search_type` provided as an argument")
@ -310,7 +333,13 @@ class OpenSearchVectorSearch(VectorStore):
response = self.client.search(index=self.index_name, body=search_query) response = self.client.search(index=self.index_name, body=search_query)
hits = [hit["_source"] for hit in response["hits"]["hits"][:k]] hits = [hit["_source"] for hit in response["hits"]["hits"][:k]]
documents = [ documents = [
Document(page_content=hit["text"], metadata=hit["metadata"]) for hit in hits Document(
page_content=hit[text_field],
metadata=hit
if metadata_field == "*" or metadata_field not in hit
else hit[metadata_field],
)
for hit in hits
] ]
return documents return documents