From 4c3a67122f63b89dd9a89f78884f4c22489c357f Mon Sep 17 00:00:00 2001 From: Paulo Nascimento <37284051+paulonasc@users.noreply.github.com> Date: Thu, 28 Mar 2024 17:23:14 -0700 Subject: [PATCH] community[patch]: add Integration for OpenAI image gen with v1 sdk (#17771) **Description:** Created a Langchain Tool for OpenAI DALLE Image Generation. **Issue:** [#15901](https://github.com/langchain-ai/langchain/issues/15901) **Dependencies:** n/a **Twitter handle:** @paulodoestech - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Bagatur --- .../openai_dalle_image_generation/__init__.py | 7 +++++ .../openai_dalle_image_generation/tool.py | 29 +++++++++++++++++++ .../openai_dalle_image_generation/__init__.py | 0 .../test_image_generation.py | 15 ++++++++++ 4 files changed, 51 insertions(+) create mode 100644 libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py create mode 100644 libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py create mode 100644 libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py create mode 100644 libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py diff --git a/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py b/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py new file mode 100644 index 0000000000..dbdf41b11b --- /dev/null +++ b/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py @@ -0,0 +1,7 @@ +"""Tool to generate an image using DALLE OpenAI V1 SDK.""" + +from langchain_community.tools.openai_dalle_image_generation.tool import ( + OpenAIDALLEImageGenerationTool, +) + +__all__ = ["OpenAIDALLEImageGenerationTool"] diff --git a/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py new file mode 100644 index 0000000000..36374e887f --- /dev/null +++ b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py @@ -0,0 +1,29 @@ +"""Tool for the OpenAI DALLE V1 Image Generation SDK.""" + +from typing import Optional + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.tools import BaseTool + +from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper + + +class OpenAIDALLEImageGenerationTool(BaseTool): + """Tool that generates an image using OpenAI DALLE.""" + + name: str = "openai_dalle" + description: str = ( + "A wrapper around OpenAI DALLE Image Generation. " + "Useful for when you need to generate an image of" + "people, places, paintings, animals, or other subjects. " + "Input should be a text prompt to generate an image." + ) + api_wrapper: DallEAPIWrapper + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the OpenAI DALLE Image Generation tool.""" + return self.api_wrapper.run(query) diff --git a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py new file mode 100644 index 0000000000..7358fc3292 --- /dev/null +++ b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py @@ -0,0 +1,15 @@ +from unittest.mock import MagicMock + +from langchain_community.tools.openai_dalle_image_generation import ( + OpenAIDALLEImageGenerationTool, +) + + +def test_generate_image() -> None: + """Test OpenAI DALLE Image Generation.""" + mock_api_resource = MagicMock() + # bypass pydantic validation as openai is not a package dependency + tool = OpenAIDALLEImageGenerationTool.construct(api_wrapper=mock_api_resource) + tool_input = {"query": "parrot on a branch"} + result = tool.run(tool_input) + assert result.startswith("https://")