diff --git a/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py b/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py new file mode 100644 index 0000000000..dbdf41b11b --- /dev/null +++ b/libs/community/langchain_community/tools/openai_dalle_image_generation/__init__.py @@ -0,0 +1,7 @@ +"""Tool to generate an image using DALLE OpenAI V1 SDK.""" + +from langchain_community.tools.openai_dalle_image_generation.tool import ( + OpenAIDALLEImageGenerationTool, +) + +__all__ = ["OpenAIDALLEImageGenerationTool"] diff --git a/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py new file mode 100644 index 0000000000..36374e887f --- /dev/null +++ b/libs/community/langchain_community/tools/openai_dalle_image_generation/tool.py @@ -0,0 +1,29 @@ +"""Tool for the OpenAI DALLE V1 Image Generation SDK.""" + +from typing import Optional + +from langchain_core.callbacks import CallbackManagerForToolRun +from langchain_core.tools import BaseTool + +from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper + + +class OpenAIDALLEImageGenerationTool(BaseTool): + """Tool that generates an image using OpenAI DALLE.""" + + name: str = "openai_dalle" + description: str = ( + "A wrapper around OpenAI DALLE Image Generation. " + "Useful for when you need to generate an image of" + "people, places, paintings, animals, or other subjects. " + "Input should be a text prompt to generate an image." + ) + api_wrapper: DallEAPIWrapper + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the OpenAI DALLE Image Generation tool.""" + return self.api_wrapper.run(query) diff --git a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py new file mode 100644 index 0000000000..7358fc3292 --- /dev/null +++ b/libs/community/tests/unit_tests/tools/openai_dalle_image_generation/test_image_generation.py @@ -0,0 +1,15 @@ +from unittest.mock import MagicMock + +from langchain_community.tools.openai_dalle_image_generation import ( + OpenAIDALLEImageGenerationTool, +) + + +def test_generate_image() -> None: + """Test OpenAI DALLE Image Generation.""" + mock_api_resource = MagicMock() + # bypass pydantic validation as openai is not a package dependency + tool = OpenAIDALLEImageGenerationTool.construct(api_wrapper=mock_api_resource) + tool_input = {"query": "parrot on a branch"} + result = tool.run(tool_input) + assert result.startswith("https://")