@ -4,88 +4,32 @@ from __future__ import annotations
from typing import TYPE_CHECKING , Any , Dict , List , Optional , Sequence
from langchain . callbacks . manager import CallbackManagerForRetrieverRun
from langchain . pydantic_v1 import Extra, Field , root_validator
from langchain . pydantic_v1 import BaseModel, Extra, Field , root_validator
from langchain . schema import BaseRetriever , Document
from langchain . utils import get_from_dict_or_env
if TYPE_CHECKING :
from google . api_core . client_options import ClientOptions
from google . cloud . discoveryengine_v1beta import (
ConversationalSearchServiceClient ,
SearchRequest ,
SearchResult ,
SearchServiceClient ,
)
class GoogleVertexAISearchRetriever ( BaseRetriever ) :
""" `Google Vertex AI Search` retriever.
For a detailed explanation of the Vertex AI Search concepts
and configuration parameters , refer to the product documentation .
https : / / cloud . google . com / generative - ai - app - builder / docs / enterprise - search - introduction
"""
class _BaseGoogleVertexAISearchRetriever ( BaseModel ) :
project_id : str
""" Google Cloud Project ID. """
data_store_id : str
""" Vertex AI Search data store ID. """
serving_config_id : str = " default_config "
""" Vertex AI Search serving config ID. """
location_id : str = " global "
""" Vertex AI Search data store location. """
filter : Optional [ str ] = None
""" Filter expression. """
get_extractive_answers : bool = False
""" If True return Extractive Answers, otherwise return Extractive Segments. """
max_documents : int = Field ( default = 5 , ge = 1 , le = 100 )
""" The maximum number of documents to return. """
max_extractive_answer_count : int = Field ( default = 1 , ge = 1 , le = 5 )
""" The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult .
"""
max_extractive_segment_count : int = Field ( default = 1 , ge = 1 , le = 1 )
""" The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult .
"""
query_expansion_condition : int = Field ( default = 1 , ge = 0 , le = 2 )
""" Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition . In this case , server behavior defaults
to disabled
1 - Disabled query expansion . Only the exact search query is used , even if
SearchResponse . total_size is zero .
2 - Automatic query expansion built by the Search API .
"""
spell_correction_mode : int = Field ( default = 2 , ge = 0 , le = 2 )
""" Specification to determine under which conditions spell correction should occur.
0 - Unspecified spell correction mode . In this case , server behavior defaults
to auto .
1 - Suggestion only . Search API will try to find a spell suggestion if there is any
and put in the ` SearchResponse . corrected_query ` .
The spell suggestion will not be used as the search query .
2 - Automatic spell correction built by the Search API .
Search will be based on the corrected query if found .
"""
credentials : Any = None
""" The default custom credentials (google.auth.credentials.Credentials) to use
when making API calls . If not provided , credentials will be ascertained from
the environment . """
# TODO: Add extra data type handling for type website
engine_data_type : int = Field ( default = 0 , ge = 0 , le = 1 )
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
_client : SearchServiceClient
_serving_config : str
class Config :
""" Configuration for this pydantic object. """
extra = Extra . ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
@root_validator ( pre = True )
def validate_environment ( cls , values : Dict ) - > Dict :
""" Validates the environment. """
@ -94,9 +38,9 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
except ImportError as exc :
raise ImportError (
" google.cloud.discoveryengine is not installed. "
" Please install it with pip install google-cloud-discoveryengine "
" Please install it with pip install "
" google-cloud-discoveryengine>=0.11.0 "
) from exc
try :
from google . api_core . exceptions import InvalidArgument # noqa: F401
except ImportError as exc :
@ -130,47 +74,42 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
return values
def __init__ ( self , * * data : Any ) - > None :
""" Initializes private fields. """
try :
from google . cloud . discoveryengine_v1beta import SearchServiceClient
except ImportError as exc :
raise ImportError (
" google.cloud.discoveryengine is not installed. "
" Please install it with pip install google-cloud-discoveryengine "
) from exc
try :
from google . api_core . client_options import ClientOptions
except ImportError as exc :
raise ImportError (
" google.api_core.client_options is not installed. "
" Please install it with pip install google-api-core "
) from exc
super ( ) . __init__ ( * * data )
@property
def client_options ( self ) - > " ClientOptions " :
from google . api_core . client_options import ClientOptions
# For more information, refer to:
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
api_endpoint = (
" discoveryengine.googleapis.com "
if self . location_id == " global "
else f " { self . location_id } -discoveryengine.googleapis.com "
return ClientOptions (
api_endpoint = f " { self . location_id } -discoveryengine.googleapis.com "
if self . location_id != " global "
else None
)
self . _client = SearchServiceClient (
credentials = self . credentials ,
client_options = ClientOptions ( api_endpoint = api_endpoint ) ,
)
def _convert_structured_search_response (
self , results : Sequence [ SearchResult ]
) - > List [ Document ] :
""" Converts a sequence of search results to a list of LangChain documents. """
import json
self . _serving_config = self . _client . serving_config_path (
project = self . project_id ,
location = self . location_id ,
data_store = self . data_store_id ,
serving_config = self . serving_config_id ,
)
from google . protobuf . json_format import MessageToDict
documents : List [ Document ] = [ ]
for result in results :
document_dict = MessageToDict (
result . document . _pb , preserving_proto_field_name = True
)
documents . append (
Document (
page_content = json . dumps ( document_dict . get ( " struct_data " , { } ) ) ,
metadata = { " id " : document_dict [ " id " ] , " name " : document_dict [ " name " ] } ,
)
)
return documents
def _convert_unstructured_search_response (
self , results : Sequence [ SearchResult ]
self , results : Sequence [ SearchResult ] , chunk_type : str
) - > List [ Document ] :
""" Converts a sequence of search results to a list of LangChain documents. """
from google . protobuf . json_format import MessageToDict
@ -188,12 +127,6 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
doc_metadata = document_dict . get ( " struct_data " , { } )
doc_metadata [ " id " ] = document_dict [ " id " ]
chunk_type = (
" extractive_answers "
if self . get_extractive_answers
else " extractive_segments "
)
if chunk_type not in derived_struct_data :
continue
@ -211,29 +144,91 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
return documents
def _convert_structured_search_response (
self , results : Sequence [ SearchResult ]
) - > List [ Document ] :
""" Converts a sequence of search results to a list of LangChain documents. """
import json
from google . protobuf . json_format import MessageToDict
class GoogleVertexAISearchRetriever ( BaseRetriever , _BaseGoogleVertexAISearchRetriever ) :
""" `Google Vertex AI Search` retriever.
documents : List [ Document ] = [ ]
For a detailed explanation of the Vertex AI Search concepts
and configuration parameters , refer to the product documentation .
https : / / cloud . google . com / generative - ai - app - builder / docs / enterprise - search - introduction
"""
for result in results :
document_dict = MessageToDict (
result . document . _pb , preserving_proto_field_name = True
)
serving_config_id : str = " default_config "
""" Vertex AI Search serving config ID. """
filter : Optional [ str ] = None
""" Filter expression. """
get_extractive_answers : bool = False
""" If True return Extractive Answers, otherwise return Extractive Segments. """
max_documents : int = Field ( default = 5 , ge = 1 , le = 100 )
""" The maximum number of documents to return. """
max_extractive_answer_count : int = Field ( default = 1 , ge = 1 , le = 5 )
""" The maximum number of extractive answers returned in each search result.
At most 5 answers will be returned for each SearchResult .
"""
max_extractive_segment_count : int = Field ( default = 1 , ge = 1 , le = 1 )
""" The maximum number of extractive segments returned in each search result.
Currently one segment will be returned for each SearchResult .
"""
query_expansion_condition : int = Field ( default = 1 , ge = 0 , le = 2 )
""" Specification to determine under which conditions query expansion should occur.
0 - Unspecified query expansion condition . In this case , server behavior defaults
to disabled
1 - Disabled query expansion . Only the exact search query is used , even if
SearchResponse . total_size is zero .
2 - Automatic query expansion built by the Search API .
"""
spell_correction_mode : int = Field ( default = 2 , ge = 0 , le = 2 )
""" Specification to determine under which conditions spell correction should occur.
0 - Unspecified spell correction mode . In this case , server behavior defaults
to auto .
1 - Suggestion only . Search API will try to find a spell suggestion if there is any
and put in the ` SearchResponse . corrected_query ` .
The spell suggestion will not be used as the search query .
2 - Automatic spell correction built by the Search API .
Search will be based on the corrected query if found .
"""
documents . append (
Document (
page_content = json . dumps ( document_dict . get ( " struct_data " , { } ) ) ,
metadata = { " id " : document_dict [ " id " ] , " name " : document_dict [ " name " ] } ,
)
)
# TODO: Add extra data type handling for type website
engine_data_type : int = Field ( default = 0 , ge = 0 , le = 1 )
""" Defines the Vertex AI Search data type
0 - Unstructured data
1 - Structured data
"""
return documents
_client : SearchServiceClient
_serving_config : str
class Config :
""" Configuration for this pydantic object. """
extra = Extra . ignore
arbitrary_types_allowed = True
underscore_attrs_are_private = True
def __init__(self, **kwargs: Any) -> None:
    """Populate the pydantic fields, then the private client attributes.

    After the regular pydantic initialisation, a ``SearchServiceClient`` is
    created from ``self.credentials`` and ``self.client_options`` and the
    serving config path is derived from the project, location, data store
    and serving config IDs.

    Raises:
        ImportError: If ``google.cloud.discoveryengine_v1beta`` cannot be
            imported.
    """
    try:
        from google.cloud.discoveryengine_v1beta import (
            SearchServiceClient as search_client_cls,
        )
    except ImportError as import_error:
        raise ImportError(
            "google.cloud.discoveryengine is not installed. "
            "Please install it with pip install google-cloud-discoveryengine"
        ) from import_error

    super().__init__(**kwargs)

    # For more information, refer to:
    # https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
    search_client = search_client_cls(
        credentials=self.credentials,
        client_options=self.client_options,
    )
    self._client = search_client

    # The path is built by the client itself from the four identifiers.
    self._serving_config = search_client.serving_config_path(
        project=self.project_id,
        location=self.location_id,
        data_store=self.data_store_id,
        serving_config=self.serving_config_id,
    )
def _create_search_request ( self , query : str ) - > SearchRequest :
""" Prepares a SearchRequest object. """
@ -300,7 +295,14 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
)
if self . engine_data_type == 0 :
documents = self . _convert_unstructured_search_response ( response . results )
chunk_type = (
" extractive_answers "
if self . get_extractive_answers
else " extractive_segments "
)
documents = self . _convert_unstructured_search_response (
response . results , chunk_type
)
elif self . engine_data_type == 1 :
documents = self . _convert_structured_search_response ( response . results )
else :
@ -312,3 +314,46 @@ class GoogleVertexAISearchRetriever(BaseRetriever):
)
return documents
class GoogleVertexAIMultiTurnSearchRetriever(
    BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
    """`Google Vertex AI Search` retriever for multi-turn (conversational) search.

    Every query is sent as a ``ConverseConversationRequest`` through a
    ``ConversationalSearchServiceClient``; the search results attached to the
    reply are converted to LangChain documents.
    """

    _client: ConversationalSearchServiceClient

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the pydantic fields, then the private conversational client.

        Raises:
            ImportError: If ``ConversationalSearchServiceClient`` cannot be
                imported from ``google.cloud.discoveryengine_v1beta``.
        """
        super().__init__(**kwargs)

        # Guard the import the same way the single-turn retriever does, so a
        # missing or outdated package yields an actionable message.
        try:
            from google.cloud.discoveryengine_v1beta import (
                ConversationalSearchServiceClient,
            )
        except ImportError as exc:
            raise ImportError(
                "Could not import ConversationalSearchServiceClient from "
                "google.cloud.discoveryengine_v1beta. "
                "Please install it with pip install "
                "google-cloud-discoveryengine>=0.11.0"
            ) from exc

        self._client = ConversationalSearchServiceClient(
            credentials=self.credentials, client_options=self.client_options
        )

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: Text sent as the user input of the conversation turn.
            run_manager: Callback manager for the retriever run; not used by
                this method.

        Returns:
            The documents produced by ``_convert_unstructured_search_response``
            for the response's search results, using "extractive_answers" as
            the chunk type.
        """
        from google.cloud.discoveryengine_v1beta import (
            ConverseConversationRequest,
            TextInput,
        )

        # NOTE(review): "-" as the conversation ID presumably selects the
        # service's auto-session mode -- confirm against the API reference.
        conversation_name = self._client.conversation_path(
            self.project_id, self.location_id, self.data_store_id, "-"
        )
        request = ConverseConversationRequest(
            name=conversation_name,
            query=TextInput(input=query),
        )
        response = self._client.converse_conversation(request)

        return self._convert_unstructured_search_response(
            response.search_results, "extractive_answers"
        )