Add support for structured data sources with google enterprise search (#9037)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
- Description: Added the capability to handle structured data from
Google Enterprise Search,
- Issue: The retriever failed when the underlying search engine was integrated
with structured data,
  - Dependencies: google-api-core
  - Tag maintainer: @jarokaz
  - Twitter handle: anifort

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

---------

Co-authored-by: Christos Aniftos <aniftos@google.com>
Co-authored-by: Holt Skinner <13262395+holtskinner@users.noreply.github.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
anifort 2023-08-23 04:18:10 +01:00 committed by GitHub
parent 02545a54b3
commit 900c1f3e8d
6 changed files with 133 additions and 27 deletions

View File

@ -100,8 +100,12 @@
"source": [
"## Configure and use the Enterprise Search retriever\n",
"\n",
"The Enterprise Search retriever is implemented in the `langchain.retriever.GoogleCloudEntepriseSearchRetriever` class. The `get_relevan_documents` method returns a list of `langchain.schema.Document` documents where the `page_content` field of each document is populated with either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of a document from which the segments or answers were extracted.\n",
"The Enterprise Search retriever is implemented in the `langchain.retrievers.GoogleCloudEnterpriseSearchRetriever` class. The `get_relevant_documents` method returns a list of `langchain.schema.Document` documents where the `page_content` field of each document is populated with the document content.\n",
"Depending on the data type used in Enterprise Search (structured or unstructured), the `page_content` field is populated as follows:\n",
"- Unstructured data source: either an `extractive segment` or an `extractive answer` that matches a query. The `metadata` field is populated with metadata (if any) of the document from which the segments or answers were extracted.\n",
"- Structured data source: a JSON string containing all the fields returned from the structured data source. The `metadata` field is populated with the `id` and `name` of the document.\n",
"\n",
"### Only for Unstructured data sources:\n",
"An extractive answer is verbatim text that is returned with each search result. It is extracted directly from the original document. Extractive answers are typically displayed near the top of web pages to provide an end user with a brief answer that is contextually relevant to their query. Extractive answers are available for website and unstructured search.\n",
"\n",
"An extractive segment is verbatim text that is returned with each search result. An extractive segment is usually more verbose than an extractive answer. Extractive segments can be displayed as an answer to a query, and can be used to perform post-processing tasks and as input for large language models to generate answers or new text. Extractive segments are available for unstructured search.\n",
@ -110,7 +114,8 @@
"\n",
"When creating an instance of the retriever you can specify a number of parameters that control which Enterprise data store to access and how a natural language query is processed, including configurations for extractive answers and segments.\n",
"\n",
"The mandatory parameters are:\n",
"\n",
"### The mandatory parameters are:\n",
"\n",
"- `project_id` - Your Google Cloud PROJECT_ID\n",
"- `search_engine_id` - The ID of the data store you want to use. \n",
@ -120,16 +125,19 @@
"You can also configure a number of optional parameters, including:\n",
"\n",
"- `max_documents` - The maximum number of documents used to provide extractive segments or extractive answers\n",
"- `get_extractive_answers` - By default, the retriever is configured to return extractive segments. Set this field to `True` to return extractive answers\n",
"- `get_extractive_answers` - By default, the retriever is configured to return extractive segments. Set this field to `True` to return extractive answers. This is used only when `engine_data_type` is set to 0 (unstructured).\n",
"- `max_extractive_answer_count` - The maximum number of extractive answers returned in each search result.\n",
" At most 5 answers will be returned\n",
" At most 5 answers will be returned. This is used only when `engine_data_type` is set to 0 (unstructured).\n",
"- `max_extractive_segment_count` - The maximum number of extractive segments returned in each search result.\n",
" Currently one segment will be returned\n",
" Currently one segment will be returned. This is used only when `engine_data_type` is set to 0 (unstructured).\n",
"- `filter` - The filter expression that allows you to filter the search results based on the metadata associated with the documents in the searched data store. \n",
"- `query_expansion_condition` - Specification to determine under which conditions query expansion should occur.\n",
" 0 - Unspecified query expansion condition. In this case, server behavior defaults to disabled.\n",
" 1 - Disabled query expansion. Only the exact search query is used, even if SearchResponse.total_size is zero.\n",
" 2 - Automatic query expansion built by the Search API.\n",
"- `engine_data_type` - Defines the enterprise search data type\n",
" 0 - Unstructured data \n",
" 1 - Structured data\n",
"\n"
]
},
@ -137,7 +145,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retriever with extractve segments"
"### Configure and use the retriever for **unstructured** data with extractive segments "
]
},
{
@ -182,7 +190,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retriever with extractve answers "
"### Configure and use the retriever for **unstructured** data with extractive answers "
]
},
{
@ -213,12 +221,30 @@
" print(doc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retriever for **structured** data "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"retriever = GoogleCloudEnterpriseSearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" search_engine_id=SEARCH_ENGINE_ID,\n",
" max_documents=3,\n",
" engine_data_type=1\n",
")\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
}
],
"metadata": {
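
Note on consuming structured results: for a structured data source, `page_content` holds a JSON string, so callers will usually want to decode it before use. The following is a minimal sketch of that step, not part of this commit. The `Document` dataclass is a stand-in for `langchain.schema.Document` (assumed to expose the same two fields), and `parse_structured_documents`, the `_id` key, and the sample values are hypothetical names chosen for illustration.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Document:
    """Stand-in for langchain.schema.Document (assumption: same two fields)."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def parse_structured_documents(docs: List[Document]) -> List[Dict[str, Any]]:
    """Decode the JSON in page_content and attach the document id.

    Hypothetical helper: the id is stored under "_id" to avoid clashing
    with a user-defined "id" field in the structured data.
    """
    rows = []
    for doc in docs:
        fields = json.loads(doc.page_content)
        rows.append({"_id": doc.metadata.get("id"), **fields})
    return rows


# Sample data shaped like the retriever's structured output (illustrative only).
sample = [
    Document(
        page_content=json.dumps({"title": "Widget", "price": 9.5}),
        metadata={"id": "doc-1", "name": "documents/doc-1"},
    )
]
rows = parse_structured_documents(sample)
print(rows)
```

With a real retriever configured with `engine_data_type=1`, the list returned by `get_relevant_documents(query)` could be passed to the helper in place of `sample`.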

View File

@ -69,6 +69,13 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
when making API calls. If not provided, credentials will be ascertained from
the environment."""
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the enterprise search data type
0 - Unstructured data
1 - Structured data
"""
_client: SearchServiceClient
_serving_config: str
@ -86,10 +93,18 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
from google.cloud import discoveryengine_v1beta # noqa: F401
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed. "
"google.cloud.discoveryengine is not installed. "
"Please install it with pip install google-cloud-discoveryengine"
) from exc
try:
from google.api_core.exceptions import InvalidArgument # noqa: F401
except ImportError as exc:
raise ImportError(
"google.api_core.exceptions is not installed. "
"Please install it with pip install google-api-core"
) from exc
values["project_id"] = get_from_dict_or_env(values, "project_id", "PROJECT_ID")
values["search_engine_id"] = get_from_dict_or_env(
values, "search_engine_id", "SEARCH_ENGINE_ID"
@ -110,7 +125,7 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
serving_config=self.serving_config_id,
)
def _convert_search_response(
def _convert_unstructured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
@ -149,6 +164,30 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
return documents
def _convert_structured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
import json
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
documents.append(
Document(
page_content=json.dumps(document_dict.get("struct_data", {})),
metadata={"id": document_dict["id"], "name": document_dict["name"]},
)
)
return documents
def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object."""
from google.cloud.discoveryengine_v1beta import SearchRequest
@ -161,23 +200,32 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
mode=self.spell_correction_mode
)
if self.get_extractive_answers:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count,
if self.engine_data_type == 0:
if self.get_extractive_answers:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_answer_count=self.max_extractive_answer_count,
)
)
else:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_segment_count=self.max_extractive_segment_count,
)
)
content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=extractive_content_spec
)
elif self.engine_data_type == 1:
content_search_spec = None
else:
extractive_content_spec = (
SearchRequest.ContentSearchSpec.ExtractiveContentSpec(
max_extractive_segment_count=self.max_extractive_segment_count,
)
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
+ f" Got {self.engine_data_type}"
)
content_search_spec = SearchRequest.ContentSearchSpec(
extractive_content_spec=extractive_content_spec,
)
return SearchRequest(
query=query,
filter=self.filter,
@ -192,8 +240,27 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
from google.api_core.exceptions import InvalidArgument
search_request = self._create_search_request(query)
response = self._client.search(search_request)
documents = self._convert_search_response(response.results)
try:
response = self._client.search(search_request)
except InvalidArgument as e:
raise type(e)(
e.message + " This might be due to engine_data_type not set correctly."
) from e
if self.engine_data_type == 0:
documents = self._convert_unstructured_search_response(response.results)
elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results)
else:
# TODO: Add extra data type handling for type website
raise NotImplementedError(
"Only engine data type 0 (Unstructured) or 1 (Structured)"
+ " are supported currently."
+ f" Got {self.engine_data_type}"
)
return documents
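
Note on the error handling in `_get_relevant_documents`: the pattern of catching an exception, appending a hint to its message, and re-raising the same exception type can be sketched in isolation as follows. This is an illustrative sketch under stated assumptions, not the library's code. `InvalidArgument` here is a local stand-in for `google.api_core.exceptions.InvalidArgument` (assumed to carry a `message` attribute, as the diff uses it), and `search_with_hint` and `failing_search` are hypothetical names. The sketch chains with `from e` so the original exception stays attached as `__cause__`.

```python
from typing import Any, Callable


class InvalidArgument(Exception):
    """Local stand-in for google.api_core.exceptions.InvalidArgument."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


HINT = " This might be due to engine_data_type not set correctly."


def search_with_hint(search: Callable[[Any], Any], request: Any) -> Any:
    """Run a search callable and enrich InvalidArgument errors with a hint."""
    try:
        return search(request)
    except InvalidArgument as e:
        # Re-raise the same exception type with a longer message,
        # keeping the original error as the explicit cause.
        raise type(e)(e.message + HINT) from e


def failing_search(request: Any) -> Any:
    """Hypothetical search callable that always rejects the request."""
    raise InvalidArgument("400 Invalid request.")


try:
    search_with_hint(failing_search, {"query": "example"})
except InvalidArgument as err:
    caught = err
    print(caught.message)
```

Re-raising with `type(e)` keeps existing `except InvalidArgument` handlers in calling code working, while the appended hint points users at the most likely misconfiguration.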

View File

@ -3522,7 +3522,6 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
]
[[package]]
@ -8114,8 +8113,10 @@ description = "Fast and Safe Tensor serialization"
optional = true
python-versions = "*"
files = [
{file = "safetensors-0.3.2-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:4c7827b64b1da3f082301b5f5a34331b8313104c14f257099a12d32ac621c5cd"},
{file = "safetensors-0.3.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:b6a66989075c2891d743153e8ba9ca84ee7232c8539704488f454199b8b8f84d"},
{file = "safetensors-0.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:670d6bc3a3b377278ce2971fa7c36ebc0a35041c4ea23b9df750a39380800195"},
{file = "safetensors-0.3.2-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:67ef2cc747c88e3a8d8e4628d715874c0366a8ff1e66713a9d42285a429623ad"},
{file = "safetensors-0.3.2-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:564f42838721925b5313ae864ba6caa6f4c80a9fbe63cf24310c3be98ab013cd"},
{file = "safetensors-0.3.2-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:7f80af7e4ab3188daaff12d43d078da3017a90d732d38d7af4eb08b6ca2198a5"},
{file = "safetensors-0.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec30d78f20f1235b252d59cbb9755beb35a1fde8c24c89b3c98e6a1804cfd432"},
@ -8124,7 +8125,9 @@ files = [
{file = "safetensors-0.3.2-cp310-cp310-win32.whl", hash = "sha256:2961c1243fd0da46aa6a1c835305cc4595486f8ac64632a604d0eb5f2de76175"},
{file = "safetensors-0.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:c813920482c337d1424d306e1b05824a38e3ef94303748a0a287dea7a8c4f805"},
{file = "safetensors-0.3.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:707df34bd9b9047e97332136ad98e57028faeccdb9cfe1c3b52aba5964cc24bf"},
{file = "safetensors-0.3.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:23d1d9f74208c9dfdf852a9f986dac63e40092385f84bf0789d599efa8e6522f"},
{file = "safetensors-0.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:becc5bb85b2947eae20ed23b407ebfd5277d9a560f90381fe2c42e6c043677ba"},
{file = "safetensors-0.3.2-cp311-cp311-macosx_12_0_universal2.whl", hash = "sha256:c1913c6c549b1805e924f307159f0ee97b73ae3ce150cd2401964da015e0fa0b"},
{file = "safetensors-0.3.2-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:30a75707be5cc9686490bde14b9a371cede4af53244ea72b340cfbabfffdf58a"},
{file = "safetensors-0.3.2-cp311-cp311-macosx_13_0_universal2.whl", hash = "sha256:54ad6af663e15e2b99e2ea3280981b7514485df72ba6d014dc22dae7ba6a5e6c"},
{file = "safetensors-0.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37764b3197656ef507a266c453e909a3477dabc795962b38e3ad28226f53153b"},
@ -8132,22 +8135,28 @@ files = [
{file = "safetensors-0.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ada0fac127ff8fb04834da5c6d85a8077e6a1c9180a11251d96f8068db922a17"},
{file = "safetensors-0.3.2-cp311-cp311-win32.whl", hash = "sha256:155b82dbe2b0ebff18cde3f76b42b6d9470296e92561ef1a282004d449fa2b4c"},
{file = "safetensors-0.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:a86428d196959619ce90197731be9391b5098b35100a7228ef4643957648f7f5"},
{file = "safetensors-0.3.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:91e796b6e465d9ffaca4c411d749f236c211e257f3a8e9b25a5ffc1a42d3bfa7"},
{file = "safetensors-0.3.2-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:c1f8ab41ed735c5b581f451fd15d9602ff51aa88044bfa933c5fa4b1d0c644d1"},
{file = "safetensors-0.3.2-cp37-cp37m-macosx_12_0_x86_64.whl", hash = "sha256:e6a8ff5652493598c45cd27f5613c193d3f15e76e0f81613d399c487a7b8cc50"},
{file = "safetensors-0.3.2-cp37-cp37m-macosx_13_0_x86_64.whl", hash = "sha256:bc9cfb3c9ea2aec89685b4d656f9f2296f0f0d67ecf2bebf950870e3be89b3db"},
{file = "safetensors-0.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ace5d471e3d78e0d93f952707d808b5ab5eac77ddb034ceb702e602e9acf2be9"},
{file = "safetensors-0.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de3e20a388b444381bcda1a3193cce51825ddca277e4cf3ed1fe8d9b2d5722cd"},
{file = "safetensors-0.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d7d70d48585fe8df00725aa788f2e64fd24a4c9ae07cd6be34f6859d0f89a9c"},
{file = "safetensors-0.3.2-cp37-cp37m-win32.whl", hash = "sha256:6ff59bc90cdc857f68b1023be9085fda6202bbe7f2fd67d06af8f976d6adcc10"},
{file = "safetensors-0.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:8b05c93da15fa911763a89281906ca333ed800ab0ef1c7ce53317aa1a2322f19"},
{file = "safetensors-0.3.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:94857abc019b49a22a0065cc7741c48fb788aa7d8f3f4690c092c56090227abe"},
{file = "safetensors-0.3.2-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:8969cfd9e8d904e8d3c67c989e1bd9a95e3cc8980d4f95e4dcd43c299bb94253"},
{file = "safetensors-0.3.2-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:da482fa011dc88fe7376d8f8b42c0ccef2f260e0cbc847ceca29c708bf75a868"},
{file = "safetensors-0.3.2-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:f54148ac027556eb02187e9bc1556c4d916c99ca3cb34ca36a7d304d675035c1"},
{file = "safetensors-0.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caec25fedbcf73f66c9261984f07885680f71417fc173f52279276c7f8a5edd3"},
{file = "safetensors-0.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:50224a1d99927ccf3b75e27c3d412f7043280431ab100b4f08aad470c37cf99a"},
{file = "safetensors-0.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa98f49e95f02eb750d32c4947e7d5aa43883149ebd0414920866446525b70f0"},
{file = "safetensors-0.3.2-cp38-cp38-win32.whl", hash = "sha256:33409df5e28a83dc5cc5547a3ac17c0f1b13a1847b1eb3bc4b3be0df9915171e"},
{file = "safetensors-0.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:e04a7cbbb3856159ab99e3adb14521544f65fcb8548cce773a1435a0f8d78d27"},
{file = "safetensors-0.3.2-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:f39f3d951543b594c6bc5082149d994c47ca487fd5d55b4ce065ab90441aa334"},
{file = "safetensors-0.3.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:7c864cf5dcbfb608c5378f83319c60cc9c97263343b57c02756b7613cd5ab4dd"},
{file = "safetensors-0.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:14e8c19d6dc51d4f70ee33c46aff04c8ba3f95812e74daf8036c24bc86e75cae"},
{file = "safetensors-0.3.2-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:41b10b0a6dfe8fdfbe4b911d64717d5647e87fbd7377b2eb3d03fb94b59810ea"},
{file = "safetensors-0.3.2-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:042a60f633c3c7009fdf6a7c182b165cb7283649d2a1e9c7a4a1c23454bd9a5b"},
{file = "safetensors-0.3.2-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:fafd95e5ef41e8f312e2a32b7031f7b9b2a621b255f867b221f94bb2e9f51ae8"},
{file = "safetensors-0.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ed77cf358abce2307f03634694e0b2a29822e322a1623e0b1aa4b41e871bf8b"},
@ -10324,4 +10333,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "0247674f3f274fd2249ceb02c23a468f911a7c482796ea67252b203d1ab938ae"
content-hash = "27c44e64d872c51f42b58f9f5185f20914dc4360e91860cfc260b1acbdaa3272"

View File

@ -125,6 +125,7 @@ newspaper3k = {version = "^0.2.8", optional = true}
amazon-textract-caller = {version = "<2", optional = true}
xata = {version = "^1.0.0a7", optional = true}
xmltodict = {version = "^0.13.0", optional = true}
google-api-core = {version = "^2.11.1", optional = true}
[tool.poetry.group.test.dependencies]

View File

@ -11,12 +11,15 @@ PROJECT_ID - set to your Google Cloud project ID
SEARCH_ENGINE_ID - the ID of the search engine to use for the test
"""
import pytest
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.schema import Document
@pytest.mark.requires("google_api_core")
def test_google_cloud_enterprise_search_get_relevant_documents() -> None:
"""Test the get_relevant_documents() method."""
retriever = GoogleCloudEnterpriseSearchRetriever()
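
Note on testing without network access: the integration test above needs a live search engine, but the structured conversion logic can be exercised offline. The sketch below is an assumption-laden illustration, not a test from this commit. It reimplements the core of `_convert_structured_search_response` over plain dictionaries shaped like the output of `MessageToDict` (the `id`, `name`, and `struct_data` keys come from the diff; the function name `convert_structured_results` and the sample values are hypothetical).

```python
import json
from typing import Any, Dict, List, Tuple


def convert_structured_results(
    result_dicts: List[Dict[str, Any]]
) -> List[Tuple[str, Dict[str, str]]]:
    """Return (page_content, metadata) pairs for structured search results.

    Mirrors the conversion in the diff: struct_data is serialized to JSON,
    falling back to an empty object when the key is missing.
    """
    converted = []
    for document_dict in result_dicts:
        page_content = json.dumps(document_dict.get("struct_data", {}))
        metadata = {"id": document_dict["id"], "name": document_dict["name"]}
        converted.append((page_content, metadata))
    return converted


def test_convert_structured_results() -> None:
    """Check both the normal case and the missing struct_data fallback."""
    results = [
        {"id": "a", "name": "documents/a", "struct_data": {"color": "red"}},
        {"id": "b", "name": "documents/b"},
    ]
    converted = convert_structured_results(results)
    assert json.loads(converted[0][0]) == {"color": "red"}
    assert converted[1][0] == "{}"
    assert converted[1][1] == {"id": "b", "name": "documents/b"}


test_convert_structured_results()
```

A test of the real method would additionally need `SearchResult` objects or mocks exposing `document._pb`, which this sketch deliberately avoids.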

View File

@ -23,7 +23,7 @@ def init_repo(tmpdir: py.path.local, dir_name: str) -> str:
git.add([sample_file])
git.commit(m="Initial commit")
return repo_dir
return str(repo_dir)
@pytest.mark.requires("git")