mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
ed58eeb9c5
Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
238 lines
7.7 KiB
Python
238 lines
7.7 KiB
Python
"""Tool for the Nuclia Understanding API.
|
|
|
|
Installation:
|
|
|
|
```bash
|
|
pip install --upgrade protobuf
|
|
pip install nucliadb-protos
|
|
```
|
|
"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import mimetypes
|
|
import os
|
|
from typing import Any, Dict, Optional, Type, Union
|
|
|
|
import requests
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForToolRun,
|
|
CallbackManagerForToolRun,
|
|
)
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
from langchain_core.tools import BaseTool
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class NUASchema(BaseModel):
|
|
"""Input for Nuclia Understanding API.
|
|
|
|
Attributes:
|
|
action: Action to perform. Either `push` or `pull`.
|
|
id: ID of the file to push or pull.
|
|
path: Path to the file to push (needed only for `push` action).
|
|
text: Text content to process (needed only for `push` action).
|
|
"""
|
|
|
|
action: str = Field(
|
|
...,
|
|
description="Action to perform. Either `push` or `pull`.",
|
|
)
|
|
id: str = Field(
|
|
...,
|
|
description="ID of the file to push or pull.",
|
|
)
|
|
path: Optional[str] = Field(
|
|
...,
|
|
description="Path to the file to push (needed only for `push` action).",
|
|
)
|
|
text: Optional[str] = Field(
|
|
...,
|
|
description="Text content to process (needed only for `push` action).",
|
|
)
|
|
|
|
|
|
class NucliaUnderstandingAPI(BaseTool):
|
|
"""Tool to process files with the Nuclia Understanding API."""
|
|
|
|
name: str = "nuclia_understanding_api"
|
|
description: str = (
|
|
"A wrapper around Nuclia Understanding API endpoints. "
|
|
"Useful for when you need to extract text from any kind of files. "
|
|
)
|
|
args_schema: Type[BaseModel] = NUASchema
|
|
_results: Dict[str, Any] = {}
|
|
_config: Dict[str, Any] = {}
|
|
|
|
def __init__(self, enable_ml: bool = False) -> None:
|
|
zone = os.environ.get("NUCLIA_ZONE", "europe-1")
|
|
self._config["BACKEND"] = f"https://{zone}.nuclia.cloud/api/v1"
|
|
key = os.environ.get("NUCLIA_NUA_KEY")
|
|
if not key:
|
|
raise ValueError("NUCLIA_NUA_KEY environment variable not set")
|
|
else:
|
|
self._config["NUA_KEY"] = key
|
|
self._config["enable_ml"] = enable_ml
|
|
super().__init__()
|
|
|
|
def _run(
|
|
self,
|
|
action: str,
|
|
id: str,
|
|
path: Optional[str],
|
|
text: Optional[str],
|
|
run_manager: Optional[CallbackManagerForToolRun] = None,
|
|
) -> str:
|
|
"""Use the tool."""
|
|
if action == "push":
|
|
self._check_params(path, text)
|
|
if path:
|
|
return self._pushFile(id, path)
|
|
if text:
|
|
return self._pushText(id, text)
|
|
elif action == "pull":
|
|
return self._pull(id)
|
|
return ""
|
|
|
|
async def _arun(
|
|
self,
|
|
action: str,
|
|
id: str,
|
|
path: Optional[str] = None,
|
|
text: Optional[str] = None,
|
|
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
|
) -> str:
|
|
"""Use the tool asynchronously."""
|
|
self._check_params(path, text)
|
|
if path:
|
|
self._pushFile(id, path)
|
|
if text:
|
|
self._pushText(id, text)
|
|
data = None
|
|
while True:
|
|
data = self._pull(id)
|
|
if data:
|
|
break
|
|
await asyncio.sleep(15)
|
|
return data
|
|
|
|
def _pushText(self, id: str, text: str) -> str:
|
|
field = {
|
|
"textfield": {"text": {"body": text, "format": 0}},
|
|
"processing_options": {"ml_text": self._config["enable_ml"]},
|
|
}
|
|
return self._pushField(id, field)
|
|
|
|
def _pushFile(self, id: str, content_path: str) -> str:
|
|
with open(content_path, "rb") as source_file:
|
|
response = requests.post(
|
|
self._config["BACKEND"] + "/processing/upload",
|
|
headers={
|
|
"content-type": mimetypes.guess_type(content_path)[0]
|
|
or "application/octet-stream",
|
|
"x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
|
|
},
|
|
data=source_file.read(),
|
|
)
|
|
if response.status_code != 200:
|
|
logger.info(
|
|
f"Error uploading {content_path}: "
|
|
f"{response.status_code} {response.text}"
|
|
)
|
|
return ""
|
|
else:
|
|
field = {
|
|
"filefield": {"file": f"{response.text}"},
|
|
"processing_options": {"ml_text": self._config["enable_ml"]},
|
|
}
|
|
return self._pushField(id, field)
|
|
|
|
def _pushField(self, id: str, field: Any) -> str:
|
|
logger.info(f"Pushing {id} in queue")
|
|
response = requests.post(
|
|
self._config["BACKEND"] + "/processing/push",
|
|
headers={
|
|
"content-type": "application/json",
|
|
"x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
|
|
},
|
|
json=field,
|
|
)
|
|
if response.status_code != 200:
|
|
logger.info(
|
|
f"Error pushing field {id}:" f"{response.status_code} {response.text}"
|
|
)
|
|
raise ValueError("Error pushing field")
|
|
else:
|
|
uuid = response.json()["uuid"]
|
|
logger.info(f"Field {id} pushed in queue, uuid: {uuid}")
|
|
self._results[id] = {"uuid": uuid, "status": "pending"}
|
|
return uuid
|
|
|
|
def _pull(self, id: str) -> str:
|
|
self._pull_queue()
|
|
result = self._results.get(id, None)
|
|
if not result:
|
|
logger.info(f"{id} not in queue")
|
|
return ""
|
|
elif result["status"] == "pending":
|
|
logger.info(f'Waiting for {result["uuid"]} to be processed')
|
|
return ""
|
|
else:
|
|
return result["data"]
|
|
|
|
def _pull_queue(self) -> None:
|
|
try:
|
|
from nucliadb_protos.writer_pb2 import BrokerMessage
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"nucliadb-protos is not installed. "
|
|
"Run `pip install nucliadb-protos` to install."
|
|
) from e
|
|
try:
|
|
from google.protobuf.json_format import MessageToJson
|
|
except ImportError as e:
|
|
raise ImportError(
|
|
"Unable to import google.protobuf, please install with "
|
|
"`pip install protobuf`."
|
|
) from e
|
|
|
|
res = requests.get(
|
|
self._config["BACKEND"] + "/processing/pull",
|
|
headers={
|
|
"x-stf-nuakey": "Bearer " + self._config["NUA_KEY"],
|
|
},
|
|
).json()
|
|
if res["status"] == "empty":
|
|
logger.info("Queue empty")
|
|
elif res["status"] == "ok":
|
|
payload = res["payload"]
|
|
pb = BrokerMessage()
|
|
pb.ParseFromString(base64.b64decode(payload))
|
|
uuid = pb.uuid
|
|
logger.info(f"Pulled {uuid} from queue")
|
|
matching_id = self._find_matching_id(uuid)
|
|
if not matching_id:
|
|
logger.info(f"No matching id for {uuid}")
|
|
else:
|
|
self._results[matching_id]["status"] = "done"
|
|
data = MessageToJson(
|
|
pb,
|
|
preserving_proto_field_name=True,
|
|
including_default_value_fields=True,
|
|
)
|
|
self._results[matching_id]["data"] = data
|
|
|
|
def _find_matching_id(self, uuid: str) -> Union[str, None]:
|
|
for id, result in self._results.items():
|
|
if result["uuid"] == uuid:
|
|
return id
|
|
return None
|
|
|
|
def _check_params(self, path: Optional[str], text: Optional[str]) -> None:
|
|
if not path and not text:
|
|
raise ValueError("File path or text is required")
|
|
if path and text:
|
|
raise ValueError("Cannot process both file and text on a single run")
|