added a multiturn search based on Vertex AI Search (#11885)

Replace this entire comment with:
- **Description:** Added a retriever based on multi-turn Vertex AI
Search
  - **Twitter handle:** lkuligin
pull/11764/head^2
Leonid Kuligin 12 months ago committed by GitHub
parent 38ed55245f
commit d269dd2e2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -161,7 +161,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import GoogleVertexAISearchRetriever\n",
"from langchain.retrievers import GoogleVertexAISearchRetriever, GoogleVertexAIMultiTurnSearchRetriever\n",
"\n",
"PROJECT_ID = \"<YOUR PROJECT ID>\" # Set to your Project ID\n",
"LOCATION_ID = \"<YOUR LOCATION>\" # Set to your data store location\n",
@ -247,6 +247,37 @@
"for doc in result:\n",
" print(doc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure and use the retrieve for multi-turn search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Search with follow-ups is [based](https://cloud.google.com/generative-ai-app-builder/docs/multi-turn-search) on generative AI models and it is different from the regular unstructured data search."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = GoogleVertexAIMultiTurnSearchRetriever(\n",
" project_id=PROJECT_ID,\n",
" location_id=LOCATION_ID,\n",
" data_store_id=DATA_STORE_ID\n",
")\n",
"\n",
"result = retriever.get_relevant_documents(query)\n",
"for doc in result:\n",
" print(doc)"
]
}
],
"metadata": {

@ -35,6 +35,7 @@ from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import (
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)
from langchain.retrievers.kay import KayAiRetriever
@ -79,6 +80,7 @@ __all__ = [
"ElasticSearchBM25Retriever",
"GoogleDocumentAIWarehouseRetriever",
"GoogleCloudEnterpriseSearchRetriever",
"GoogleVertexAIMultiTurnSearchRetriever",
"GoogleVertexAISearchRetriever",
"KayAiRetriever",
"KNNRetriever",

@ -4,88 +4,32 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.pydantic_v1 import Extra, Field, root_validator
from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
from google.api_core.client_options import ClientOptions
from google.cloud.discoveryengine_v1beta import (
ConversationalSearchServiceClient,
SearchRequest,
SearchResult,
SearchServiceClient,
)
class GoogleVertexAISearchRetriever(BaseRetriever):
"""`Google Vertex AI Search` retriever.
For a detailed explanation of the Vertex AI Search concepts
and configuration parameters, refer to the product documentation.
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
"""
class _BaseGoogleVertexAISearchRetriever(BaseModel):
project_id: str
"""Google Cloud Project ID."""
data_store_id: str
"""Vertex AI Search data store ID."""
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
location_id: str = "global"
"""Vertex AI Search data store location."""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
"""If True return Extractive Answers, otherwise return Extractive Segments."""
max_documents: int = Field(default=5, ge=1, le=100)
"""The maximum number of documents to return."""
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
"""The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult.
"""
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
"""The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult.
"""
query_expansion_condition: int = Field(default=1, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition. In this case, server behavior defaults
to disabled
1 - Disabled query expansion. Only the exact search query is used, even if
SearchResponse.total_size is zero.
2 - Automatic query expansion built by the Search API.
"""
spell_correction_mode: int = Field(default=2, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified spell correction mode. In this case, server behavior defaults
to auto.
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
and put in the `SearchResponse.corrected_query`.
The spell suggestion will not be used as the search query.
2 - Automatic spell correction built by the Search API.
Search will be based on the corrected query if found.
"""
credentials: Any = None
"""The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls. If not provided, credentials will be ascertained from
the environment."""
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
_client: SearchServiceClient
_serving_config: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validates the environment."""
@ -94,9 +38,9 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
"Please install it with pip install "
"google-cloud-discoveryengine>=0.11.0"
) from exc
try:
from google.api_core.exceptions import InvalidArgument # noqa: F401
except ImportError as exc:
@ -130,47 +74,42 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
return values
def __init__(self, **data: Any) -> None:
"""Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
) from exc
try:
from google.api_core.client_options import ClientOptions
except ImportError as exc:
raise ImportError(
"google.api_core.client_options is not installed."
"Please install it with pip install google-api-core"
) from exc
super().__init__(**data)
@property
def client_options(self) -> "ClientOptions":
from google.api_core.client_options import ClientOptions
# For more information, refer to:
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
api_endpoint = (
"discoveryengine.googleapis.com"
if self.location_id == "global"
else f"{self.location_id}-discoveryengine.googleapis.com"
return ClientOptions(
api_endpoint=f"{self.location_id}-discoveryengine.googleapis.com"
if self.location_id != "global"
else None
)
self._client = SearchServiceClient(
credentials=self.credentials,
client_options=ClientOptions(api_endpoint=api_endpoint),
)
def _convert_structured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
import json
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
from google.protobuf.json_format import MessageToDict
documents: List[Document] = []
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
documents.append(
Document(
page_content=json.dumps(document_dict.get("struct_data", {})),
metadata={"id": document_dict["id"], "name": document_dict["name"]},
)
)
return documents
def _convert_unstructured_search_response(
self, results: Sequence[SearchResult]
self, results: Sequence[SearchResult], chunk_type: str
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
from google.protobuf.json_format import MessageToDict
@ -188,12 +127,6 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
doc_metadata = document_dict.get("struct_data", {})
doc_metadata["id"] = document_dict["id"]
chunk_type = (
"extractive_answers"
if self.get_extractive_answers
else "extractive_segments"
)
if chunk_type not in derived_struct_data:
continue
@ -211,29 +144,91 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
return documents
def _convert_structured_search_response(
self, results: Sequence[SearchResult]
) -> List[Document]:
"""Converts a sequence of search results to a list of LangChain documents."""
import json
from google.protobuf.json_format import MessageToDict
class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetriever):
"""`Google Vertex AI Search` retriever.
documents: List[Document] = []
For a detailed explanation of the Vertex AI Search concepts
and configuration parameters, refer to the product documentation.
https://cloud.google.com/generative-ai-app-builder/docs/enterprise-search-introduction
"""
for result in results:
document_dict = MessageToDict(
result.document._pb, preserving_proto_field_name=True
)
serving_config_id: str = "default_config"
"""Vertex AI Search serving config ID."""
filter: Optional[str] = None
"""Filter expression."""
get_extractive_answers: bool = False
"""If True return Extractive Answers, otherwise return Extractive Segments."""
max_documents: int = Field(default=5, ge=1, le=100)
"""The maximum number of documents to return."""
max_extractive_answer_count: int = Field(default=1, ge=1, le=5)
"""The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult.
"""
max_extractive_segment_count: int = Field(default=1, ge=1, le=1)
"""The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult.
"""
query_expansion_condition: int = Field(default=1, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition. In this case, server behavior defaults
to disabled
1 - Disabled query expansion. Only the exact search query is used, even if
SearchResponse.total_size is zero.
2 - Automatic query expansion built by the Search API.
"""
spell_correction_mode: int = Field(default=2, ge=0, le=2)
"""Specification to determine under which conditions query expansion should occur.
0 - Unspecified spell correction mode. In this case, server behavior defaults
to auto.
1 - Suggestion only. Search API will try to find a spell suggestion if there is any
and put in the `SearchResponse.corrected_query`.
The spell suggestion will not be used as the search query.
2 - Automatic spell correction built by the Search API.
Search will be based on the corrected query if found.
"""
documents.append(
Document(
page_content=json.dumps(document_dict.get("struct_data", {})),
metadata={"id": document_dict["id"], "name": document_dict["name"]},
)
)
# TODO: Add extra data type handling for type website
engine_data_type: int = Field(default=0, ge=0, le=1)
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
return documents
_client: SearchServiceClient
_serving_config: str
class Config:
"""Configuration for this pydantic object."""
extra = Extra.ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
def __init__(self, **kwargs: Any) -> None:
"""Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient
except ImportError as exc:
raise ImportError(
"google.cloud.discoveryengine is not installed."
"Please install it with pip install google-cloud-discoveryengine"
) from exc
super().__init__(**kwargs)
# For more information, refer to:
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
self._client = SearchServiceClient(
credentials=self.credentials, client_options=self.client_options
)
self._serving_config = self._client.serving_config_path(
project=self.project_id,
location=self.location_id,
data_store=self.data_store_id,
serving_config=self.serving_config_id,
)
def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object."""
@ -300,7 +295,14 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
)
if self.engine_data_type == 0:
documents = self._convert_unstructured_search_response(response.results)
chunk_type = (
"extractive_answers"
if self.get_extractive_answers
else "extractive_segments"
)
documents = self._convert_unstructured_search_response(
response.results, chunk_type
)
elif self.engine_data_type == 1:
documents = self._convert_structured_search_response(response.results)
else:
@ -312,3 +314,46 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
)
return documents
class GoogleVertexAIMultiTurnSearchRetriever(
BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
_client: ConversationalSearchServiceClient
class Config:
"""Configuration for this pydantic object."""
extra = Extra.ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from google.cloud.discoveryengine_v1beta import (
ConversationalSearchServiceClient,
)
self._client = ConversationalSearchServiceClient(
credentials=self.credentials, client_options=self.client_options
)
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
from google.cloud.discoveryengine_v1beta import (
ConverseConversationRequest,
TextInput,
)
request = ConverseConversationRequest(
name=self._client.conversation_path(
self.project_id, self.location_id, self.data_store_id, "-"
),
query=TextInput(input=query),
)
response = self._client.converse_conversation(request)
return self._convert_unstructured_search_response(
response.search_results, "extractive_answers"
)

@ -7,8 +7,8 @@ google_vertex_ai_search.ipynb
to set up the app and configure authentication.
Set the following environment variables before the tests:
PROJECT_ID - set to your Google Cloud project ID
DATA_STORE_ID - the ID of the search engine to use for the test
export PROJECT_ID=... - set to your Google Cloud project ID
export DATA_STORE_ID=... - the ID of the search engine to use for the test
"""
import os
@ -18,7 +18,10 @@ import pytest
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
from langchain.retrievers.google_vertex_ai_search import GoogleVertexAISearchRetriever
from langchain.retrievers.google_vertex_ai_search import (
GoogleVertexAIMultiTurnSearchRetriever,
GoogleVertexAISearchRetriever,
)
from langchain.schema import Document
@ -35,6 +38,19 @@ def test_google_vertex_ai_search_get_relevant_documents() -> None:
assert doc.metadata["source"]
@pytest.mark.requires("google_api_core")
def test_google_vertex_ai_multiturnsearch_get_relevant_documents() -> None:
"""Test the get_relevant_documents() method."""
retriever = GoogleVertexAIMultiTurnSearchRetriever()
documents = retriever.get_relevant_documents("What are Alphabet's Other Bets?")
assert len(documents) > 0
for doc in documents:
assert isinstance(doc, Document)
assert doc.page_content
assert doc.metadata["id"]
assert doc.metadata["source"]
@pytest.mark.requires("google_api_core")
def test_google_vertex_ai_search_enterprise_search_deprecation() -> None:
"""Test the deprecation of GoogleCloudEnterpriseSearchRetriever."""

Loading…
Cancel
Save