forked from Archives/langchain
e82687ddf4
Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com>
153 lines
5.5 KiB
Python
153 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import root_validator
|
|
|
|
from langchain.callbacks.manager import (
|
|
AsyncCallbackManagerForToolRun,
|
|
CallbackManagerForToolRun,
|
|
)
|
|
from langchain.tools.azure_cognitive_services.utils import detect_file_src_type
|
|
from langchain.tools.base import BaseTool
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AzureCogsFormRecognizerTool(BaseTool):
    """Tool that queries the Azure Cognitive Services Form Recognizer API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/applied-ai-services/form-recognizer/quickstarts/get-started-sdks-rest-api?view=form-recog-3.0.0&pivots=programming-language-python

    The tool takes a document location, classifies it as ``"local"`` or
    ``"remote"`` with ``detect_file_src_type``, analyzes it with the
    ``prebuilt-document`` model and returns the extracted content, tables
    and key-value pairs as a single formatted string.
    """

    azure_cogs_key: str = ""  #: :meta private:
    azure_cogs_endpoint: str = ""  #: :meta private:
    doc_analysis_client: Any  #: :meta private:

    name = "azure_cognitive_services_form_recognizer"
    description = (
        "A wrapper around Azure Cognitive Services Form Recognizer. "
        "Useful for when you need to "
        "extract text, tables, and key-value pairs from documents. "
        "Input should be a url to a document."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment.

        The key and endpoint are resolved through ``get_from_dict_or_env``
        using the ``azure_cogs_key`` / ``AZURE_COGS_KEY`` and
        ``azure_cogs_endpoint`` / ``AZURE_COGS_ENDPOINT`` names, and a
        ``DocumentAnalysisClient`` built from them is stored under
        ``values["doc_analysis_client"]``.

        Args:
            values: Raw field values passed to the model constructor.

        Returns:
            The same ``values`` dict with ``doc_analysis_client`` added.

        Raises:
            ImportError: If ``azure-ai-formrecognizer`` cannot be imported.
        """
        azure_cogs_key = get_from_dict_or_env(
            values, "azure_cogs_key", "AZURE_COGS_KEY"
        )
        azure_cogs_endpoint = get_from_dict_or_env(
            values, "azure_cogs_endpoint", "AZURE_COGS_ENDPOINT"
        )

        # Only the optional imports are guarded, so a failure while building
        # the client is not misreported as a missing package.
        try:
            from azure.ai.formrecognizer import DocumentAnalysisClient
            from azure.core.credentials import AzureKeyCredential
        except ImportError as e:
            raise ImportError(
                "azure-ai-formrecognizer is not installed. "
                "Run `pip install azure-ai-formrecognizer` to install."
            ) from e

        values["doc_analysis_client"] = DocumentAnalysisClient(
            endpoint=azure_cogs_endpoint,
            credential=AzureKeyCredential(azure_cogs_key),
        )
        return values

    def _parse_tables(self, tables: List[Any]) -> List[Any]:
        """Convert analyzed tables into row-major grids of cell text.

        Args:
            tables: Objects exposing ``row_count``, ``column_count`` and
                ``cells``; each cell exposes ``row_index``, ``column_index``
                and ``content``.

        Returns:
            One list of rows per table. Positions with no cell stay ``""``.
        """
        result = []
        for table in tables:
            rc, cc = table.row_count, table.column_count
            # Pre-fill the grid so positions without a cell remain "".
            grid = [["" for _ in range(cc)] for _ in range(rc)]
            for cell in table.cells:
                grid[cell.row_index][cell.column_index] = cell.content
            result.append(grid)
        return result

    def _parse_kv_pairs(self, kv_pairs: List[Any]) -> List[Any]:
        """Convert analyzed key-value pairs into ``(key, value)`` tuples.

        Args:
            kv_pairs: Objects exposing ``key`` and ``value``; each of those,
                when truthy, exposes ``content``.

        Returns:
            A list of ``(key, value)`` tuples. A falsy ``key`` or ``value``
            is represented by ``""``.
        """
        return [
            (
                kv_pair.key.content if kv_pair.key else "",
                kv_pair.value.content if kv_pair.value else "",
            )
            for kv_pair in kv_pairs
        ]

    def _document_analysis(self, document_path: str) -> Dict:
        """Analyze a document with the ``prebuilt-document`` model.

        Args:
            document_path: Location of the document. ``detect_file_src_type``
                decides whether it is treated as ``"local"`` (opened in
                binary mode) or ``"remote"`` (passed on as a URL).

        Returns:
            A dict holding ``"content"``, ``"tables"`` and
            ``"key_value_pairs"``; a key is present only when the matching
            attribute of the analysis result is not ``None``.

        Raises:
            ValueError: If the source type is neither ``"local"`` nor
                ``"remote"``.
        """
        document_src_type = detect_file_src_type(document_path)
        if document_src_type == "local":
            # The file is closed as soon as the begin_* call returns; the
            # poller result is fetched afterwards, as in the original code.
            with open(document_path, "rb") as document:
                poller = self.doc_analysis_client.begin_analyze_document(
                    "prebuilt-document", document
                )
        elif document_src_type == "remote":
            poller = self.doc_analysis_client.begin_analyze_document_from_url(
                "prebuilt-document", document_path
            )
        else:
            raise ValueError(f"Invalid document path: {document_path}")

        result = poller.result()
        res_dict = {}

        if result.content is not None:
            res_dict["content"] = result.content

        if result.tables is not None:
            res_dict["tables"] = self._parse_tables(result.tables)

        if result.key_value_pairs is not None:
            res_dict["key_value_pairs"] = self._parse_kv_pairs(result.key_value_pairs)

        return res_dict

    def _format_document_analysis_result(self, document_analysis_result: Dict) -> str:
        """Render an analysis dict as newline-separated text.

        Args:
            document_analysis_result: Dict as produced by
                ``_document_analysis``.

        Returns:
            One line for the content, one per table and one per key-value
            pair, in that order. Newlines inside an entry are replaced by
            spaces so each entry stays on a single line.
        """
        formatted_result = []
        if "content" in document_analysis_result:
            formatted_result.append(
                f"Content: {document_analysis_result['content']}".replace("\n", " ")
            )

        if "tables" in document_analysis_result:
            for i, table in enumerate(document_analysis_result["tables"]):
                formatted_result.append(f"Table {i}: {table}".replace("\n", " "))

        if "key_value_pairs" in document_analysis_result:
            for kv_pair in document_analysis_result["key_value_pairs"]:
                formatted_result.append(
                    f"{kv_pair[0]}: {kv_pair[1]}".replace("\n", " ")
                )

        return "\n".join(formatted_result)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool.

        Args:
            query: Location of the document to analyze.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The formatted analysis, or a fixed fallback message when the
            analysis produced an empty dict.

        Raises:
            RuntimeError: Wraps any exception raised during analysis or
                formatting; the original exception is attached as the cause.
        """
        try:
            document_analysis_result = self._document_analysis(query)
            if not document_analysis_result:
                return "No good document analysis result was found"

            return self._format_document_analysis_result(document_analysis_result)
        except Exception as e:
            # Explicit chaining keeps the underlying failure as __cause__.
            raise RuntimeError(
                f"Error while running AzureCogsFormRecognizerTool: {e}"
            ) from e

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously.

        Raises:
            NotImplementedError: Always; async execution is not supported.
        """
        raise NotImplementedError("AzureCogsFormRecognizerTool does not support async")
|