Retriever based on GCP DocAI Warehouse (#11400)

- **Description:** implements a retriever on top of DocAI Warehouse (to
interact with existing enterprise documents)
  https://cloud.google.com/document-ai-warehouse?hl=en
  - **Issue:** new functionality
 
@baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11707/head
Leonid Kuligin 12 months ago committed by GitHub
parent 629d9b78fa
commit 2aba9ab47e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -152,6 +152,23 @@ See a [usage example](/docs/integrations/retrievers/google_vertex_ai_search).
from langchain.retrievers import GoogleVertexAISearchRetriever
```
### Document AI Warehouse
> [Google Cloud Document AI Warehouse](https://cloud.google.com/document-ai-warehouse)
> allows enterprises to search, store, govern, and manage documents and their AI-extracted
> data and metadata in a single platform. Documents should be uploaded outside of LangChain.
>
```python
from langchain.retrievers import GoogleDocumentAIWarehouseRetriever
# project_number is the GCP project number (digits only).
docai_wh_retriever = GoogleDocumentAIWarehouseRetriever(
project_number=...
)
query = ...
# user_ldap is a required keyword argument: the retriever raises ValueError
# when it is not provided.
documents = docai_wh_retriever.get_relevant_documents(
query, user_ldap=...
)
```
## Tools
### Google Search

@ -28,6 +28,9 @@ from langchain.retrievers.contextual_compression import ContextualCompressionRet
from langchain.retrievers.docarray import DocArrayRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain.retrievers.google_cloud_documentai_warehouse import (
GoogleDocumentAIWarehouseRetriever,
)
from langchain.retrievers.google_cloud_enterprise_search import (
GoogleCloudEnterpriseSearchRetriever,
)
@ -74,6 +77,7 @@ __all__ = [
"ContextualCompressionRetriever",
"ChaindeskRetriever",
"ElasticSearchBM25Retriever",
"GoogleDocumentAIWarehouseRetriever",
"GoogleCloudEnterpriseSearchRetriever",
"GoogleVertexAISearchRetriever",
"KayAiRetriever",

@ -0,0 +1,118 @@
"""Retriever wrapper for Google Cloud Document AI Warehouse."""
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.pydantic_v1 import root_validator
from langchain.schema import BaseRetriever
from langchain.utils import get_from_dict_or_env
if TYPE_CHECKING:
from google.cloud.contentwarehouse_v1 import (
DocumentServiceClient,
RequestMetadata,
SearchDocumentsRequest,
)
from google.cloud.contentwarehouse_v1.services.document_service.pagers import (
SearchDocumentsPager,
)
class GoogleDocumentAIWarehouseRetriever(BaseRetriever):
    """A retriever based on Document AI Warehouse.

    Documents and their schemas should be created and uploaded in a
    separate flow; this retriever only searches for relevant documents,
    optionally restricted to the Document AI schema given by ``schema_id``.

    Every search must be made on behalf of a user, so ``user_ldap`` has to
    be passed as a keyword argument to ``get_relevant_documents``.

    More info: https://cloud.google.com/document-ai-warehouse.
    """

    location: str = "us"
    "GCP location where DocAI Warehouse is placed."
    project_number: str
    "GCP project number, should contain digits only."
    schema_id: Optional[str] = None
    "DocAI Warehouse schema to query against. If nothing is provided, all documents "
    "in the project will be searched."
    qa_size_limit: int = 5
    "The limit on the number of documents returned."
    client: "DocumentServiceClient" = None  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate the environment and create the Document AI Warehouse client.

        Args:
            values: Field values collected for the model.

        Returns:
            The same mapping with ``project_number`` resolved (from the given
            values or the ``PROJECT_NUMBER`` environment variable) and
            ``client`` set to a new ``DocumentServiceClient``.

        Raises:
            ImportError: If ``google-cloud-contentwarehouse`` is not installed.
        """
        try:
            from google.cloud.contentwarehouse_v1 import (
                DocumentServiceClient,
            )
        except ImportError as exc:
            # The two literals are concatenated, so the separating space
            # has to be part of the first one.
            raise ImportError(
                "google.cloud.contentwarehouse is not installed. "
                "Please install it with pip install google-cloud-contentwarehouse"
            ) from exc

        values["project_number"] = get_from_dict_or_env(
            values, "project_number", "PROJECT_NUMBER"
        )
        values["client"] = DocumentServiceClient()
        return values

    def _prepare_request_metadata(self, user_ldap: str) -> "RequestMetadata":
        """Build the request metadata identifying the user making the search.

        Args:
            user_ldap: User identifier; it is sent as ``user:<user_ldap>``.

        Returns:
            A ``RequestMetadata`` carrying the corresponding ``UserInfo``.
        """
        from google.cloud.contentwarehouse_v1 import RequestMetadata, UserInfo

        user_info = UserInfo(id=f"user:{user_ldap}")
        return RequestMetadata(user_info=user_info)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
    ) -> List[Document]:
        """Search Document AI Warehouse and return the matching documents.

        Args:
            query: Natural-language query string.
            run_manager: Callback manager for this retriever run (unused here).
            **kwargs: Must contain ``user_ldap``.

        Returns:
            One ``Document`` per matching document in the search response.

        Raises:
            ValueError: If ``user_ldap`` is not provided.
        """
        request = self._prepare_search_request(query, **kwargs)
        response = self.client.search_documents(request=request)
        return self._parse_search_response(response=response)

    def _prepare_search_request(
        self, query: str, **kwargs: Any
    ) -> "SearchDocumentsRequest":
        """Build the ``SearchDocumentsRequest`` for a query.

        Args:
            query: Natural-language query string.
            **kwargs: Must contain ``user_ldap``.

        Returns:
            A request scoped to this project and location, restricted to
            ``schema_id`` when one is configured.

        Raises:
            ValueError: If ``user_ldap`` is not provided.
        """
        from google.cloud.contentwarehouse_v1 import (
            DocumentQuery,
            SearchDocumentsRequest,
        )

        try:
            user_ldap = kwargs["user_ldap"]
        except KeyError:
            # The KeyError adds nothing for the caller; suppress it so the
            # traceback shows only the actionable error.
            raise ValueError("Argument user_ldap should be provided!") from None

        request_metadata = self._prepare_request_metadata(user_ldap=user_ldap)

        # An empty schema list means no schema restriction is sent.
        schemas = []
        if self.schema_id:
            schemas.append(
                self.client.document_schema_path(
                    project=self.project_number,
                    location=self.location,
                    document_schema=self.schema_id,
                )
            )

        return SearchDocumentsRequest(
            parent=self.client.common_location_path(self.project_number, self.location),
            request_metadata=request_metadata,
            document_query=DocumentQuery(
                query=query, is_nl_query=True, document_schema_names=schemas
            ),
            qa_size_limit=self.qa_size_limit,
        )

    def _parse_search_response(
        self, response: "SearchDocumentsPager"
    ) -> List[Document]:
        """Convert a search response into LangChain documents.

        Args:
            response: The pager returned by ``client.search_documents``.

        Returns:
            One ``Document`` per entry in ``response.matching_documents``,
            with the search text snippet as content and the document title
            and raw document path as ``title`` / ``source`` metadata.
        """
        return [
            Document(
                page_content=doc.search_text_snippet,
                metadata={
                    "title": doc.document.title,
                    "source": doc.document.raw_document_path,
                },
            )
            for doc in response.matching_documents
        ]

@ -0,0 +1,25 @@
"""Test Google Cloud Document AI Warehouse retriever."""
import os
from langchain.retrievers import GoogleDocumentAIWarehouseRetriever
from langchain.schema import Document
def test_google_documentai_warehoure_retriever() -> None:
    """Run a live search against Document AI Warehouse.

    The test reads the project number and the user LDAP from the
    environment, so both variables must be set before running it:
        export USER_LDAP=...
        export PROJECT_NUMBER=...
    """
    # Read both settings up front so a missing variable fails before any
    # client is created.
    number, ldap = (os.environ[name] for name in ("PROJECT_NUMBER", "USER_LDAP"))

    retriever = GoogleDocumentAIWarehouseRetriever(project_number=number)
    results = retriever.get_relevant_documents(
        "What are Alphabet's Other Bets?", user_ldap=ldap
    )

    assert len(results) > 0
    assert all(isinstance(item, Document) for item in results)
Loading…
Cancel
Save