From d969f43ed8b1abe352a8b75ff33f01d767ce84a5 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Thu, 11 May 2023 00:07:36 -0700 Subject: [PATCH] Load HuggingFace Tool (#4475) # Add option to `load_huggingface_tool` Expose a method to load a huggingface Tool from the HF hub --------- Co-authored-by: Dev 2049 --- .../tools/examples/huggingface_tools.ipynb | 102 ++++++++++++++++++ langchain/agents/__init__.py | 7 +- langchain/agents/load_tools.py | 34 ++++++ tests/unit_tests/agents/test_public_api.py | 1 + 4 files changed, 143 insertions(+), 1 deletion(-) create mode 100644 docs/modules/agents/tools/examples/huggingface_tools.ipynb diff --git a/docs/modules/agents/tools/examples/huggingface_tools.ipynb b/docs/modules/agents/tools/examples/huggingface_tools.ipynb new file mode 100644 index 00000000..5fbe8d91 --- /dev/null +++ b/docs/modules/agents/tools/examples/huggingface_tools.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "40a27d3c-4e5c-4b96-b290-4c49d4fd7219", + "metadata": {}, + "source": [ + "## HuggingFace Tools\n", + "\n", + "[Huggingface Tools](https://huggingface.co/docs/transformers/v4.29.0/en/custom_tools) supporting text I/O can be\n", + "loaded directly using the `load_huggingface_tool` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1055b75-362c-452a-b40d-c9a359706a3a", + "metadata": {}, + "outputs": [], + "source": [ + "# Requires transformers>=4.29.0 and huggingface_hub>=0.14.1\n", + "!pip install --uprade transformers huggingface_hub > /dev/null" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f964bb45-fba3-4919-b022-70a602ed4354", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model_download_counter: This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. It takes the name of the category (such as text-classification, depth-estimation, etc), and returns the name of the checkpoint\n" + ] + } + ], + "source": [ + "from langchain.agents import load_huggingface_tool\n", + "\n", + "tool = load_huggingface_tool(\"lysandre/hf-model-downloads\")\n", + "\n", + "print(f\"{tool.name}: {tool.description}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "641d9d79-95bb-469d-b40a-50f37375de7f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'facebook/bart-large-mnli'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tool.run(\"text-classification\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88724222-7c10-4aff-8713-751911dc8b63", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/agents/__init__.py b/langchain/agents/__init__.py index b70dea94..86af75ff 100644 --- a/langchain/agents/__init__.py +++ b/langchain/agents/__init__.py @@ -23,7 +23,11 @@ from langchain.agents.agent_types import AgentType from langchain.agents.conversational.base import ConversationalAgent from langchain.agents.conversational_chat.base import ConversationalChatAgent from langchain.agents.initialize import initialize_agent -from langchain.agents.load_tools import get_all_tool_names, load_tools +from langchain.agents.load_tools import ( + get_all_tool_names, + load_huggingface_tool, + load_tools, +) from langchain.agents.loading import load_agent from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent from langchain.agents.react.base import ReActChain, ReActTextWorldAgent @@ -61,6 +65,7 @@ __all__ = [ "get_all_tool_names", "initialize_agent", "load_agent", + "load_huggingface_tool", "load_tools", "tool", ] diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index e7408411..e22db4c6 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -296,6 +296,40 @@ def _handle_callbacks( return callbacks +def load_huggingface_tool( + task_or_repo_id: str, + model_repo_id: Optional[str] = None, + token: Optional[str] = None, + remote: bool = False, + **kwargs: Any, +) -> BaseTool: + try: + from transformers import load_tool + except ImportError: + raise ValueError( + "HuggingFace tools require the libraries `transformers>=4.29.0`" + " and `huggingface_hub>=0.14.1` to be installed." + " Please install it with" + " `pip install --upgrade transformers huggingface_hub`." + ) + hf_tool = load_tool( + task_or_repo_id, + model_repo_id=model_repo_id, + token=token, + remote=remote, + **kwargs, + ) + outputs = hf_tool.outputs + if set(outputs) != {"text"}: + raise NotImplementedError("Multimodal outputs not supported yet.") + inputs = hf_tool.inputs + if set(inputs) != {"text"}: + raise NotImplementedError("Multimodal inputs not supported yet.") + return Tool.from_function( + hf_tool.__call__, name=hf_tool.name, description=hf_tool.description + ) + + def load_tools( tool_names: List[str], llm: Optional[BaseLanguageModel] = None, diff --git a/tests/unit_tests/agents/test_public_api.py b/tests/unit_tests/agents/test_public_api.py index 05055a93..0b6f9df9 100644 --- a/tests/unit_tests/agents/test_public_api.py +++ b/tests/unit_tests/agents/test_public_api.py @@ -30,6 +30,7 @@ _EXPECTED = [ "get_all_tool_names", "initialize_agent", "load_agent", + "load_huggingface_tool", "load_tools", "tool", ]