mirror of https://github.com/hwchase17/langchain
community[patch]: add Integration for OpenAI image gen with v1 sdk (#17771)
**Description:** Created a Langchain Tool for OpenAI DALLE Image Generation. **Issue:** [#15901](https://github.com/langchain-ai/langchain/issues/15901) **Dependencies:** n/a **Twitter handle:** @paulodoestech - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>pull/17721/head^2
parent
a8104ea8e9
commit
4c3a67122f
@ -0,0 +1,7 @@
|
||||
"""Tool to generate an image using DALLE OpenAI V1 SDK."""
|
||||
|
||||
from langchain_community.tools.openai_dalle_image_generation.tool import (
|
||||
OpenAIDALLEImageGenerationTool,
|
||||
)
|
||||
|
||||
__all__ = ["OpenAIDALLEImageGenerationTool"]
|
@ -0,0 +1,29 @@
|
||||
"""Tool for the OpenAI DALLE V1 Image Generation SDK."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper
|
||||
|
||||
|
||||
class OpenAIDALLEImageGenerationTool(BaseTool):
|
||||
"""Tool that generates an image using OpenAI DALLE."""
|
||||
|
||||
name: str = "openai_dalle"
|
||||
description: str = (
|
||||
"A wrapper around OpenAI DALLE Image Generation. "
|
||||
"Useful for when you need to generate an image of"
|
||||
"people, places, paintings, animals, or other subjects. "
|
||||
"Input should be a text prompt to generate an image."
|
||||
)
|
||||
api_wrapper: DallEAPIWrapper
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the OpenAI DALLE Image Generation tool."""
|
||||
return self.api_wrapper.run(query)
|
@ -0,0 +1,15 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_community.tools.openai_dalle_image_generation import (
|
||||
OpenAIDALLEImageGenerationTool,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_image() -> None:
|
||||
"""Test OpenAI DALLE Image Generation."""
|
||||
mock_api_resource = MagicMock()
|
||||
# bypass pydantic validation as openai is not a package dependency
|
||||
tool = OpenAIDALLEImageGenerationTool.construct(api_wrapper=mock_api_resource)
|
||||
tool_input = {"query": "parrot on a branch"}
|
||||
result = tool.run(tool_input)
|
||||
assert result.startswith("https://")
|
Loading…
Reference in New Issue