From 9dead1034cbf143e784a1412e210116a76a62ba0 Mon Sep 17 00:00:00 2001 From: silvhua <111714881+silvhua@users.noreply.github.com> Date: Sun, 29 Oct 2023 14:23:23 -0700 Subject: [PATCH] `_dalle_image_url` returns list of urls if n>1 (#11800) - **Description:** Updated the `_dalle_image_url` method to return a list of URLs if self.n>1, - **Issue:** #10691, - **Dependencies:** unsure, - **Tag maintainer:** @eyurtsev, - **Twitter handle:** @silvhua --------- Co-authored-by: Bagatur --- .../utilities/dalle_image_generator.py | 32 +++++++------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/libs/langchain/langchain/utilities/dalle_image_generator.py b/libs/langchain/langchain/utilities/dalle_image_generator.py index dc25c44b73..f5d37b514a 100644 --- a/libs/langchain/langchain/utilities/dalle_image_generator.py +++ b/libs/langchain/langchain/utilities/dalle_image_generator.py @@ -8,29 +8,25 @@ from langchain.utils import get_from_dict_or_env class DallEAPIWrapper(BaseModel): """Wrapper for OpenAI's DALL-E Image Generator. - Docs for using: - 1. pip install openai + Usage instructions: + 1. `pip install openai` 2. save your OPENAI_API_KEY in an environment variable - """ client: Any #: :meta private: openai_api_key: Optional[str] = None - """number of images to generate""" n: int = 1 - """size of image to generate""" + """Number of images to generate""" size: str = "1024x1024" + """Size of image to generate""" + separator: str = "\n" + """Separator to use when multiple URLs are returned.""" class Config: """Configuration for this pydantic object.""" extra = Extra.forbid - def _dalle_image_url(self, prompt: str) -> str: - params = {"prompt": prompt, "n": self.n, "size": self.size} - response = self.client.create(**params) - return response["data"][0]["url"] - @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -42,19 +38,15 @@ class DallEAPIWrapper(BaseModel): openai.api_key = openai_api_key values["client"] = openai.Image - except ImportError: - raise ValueError( + except ImportError as e: + raise ImportError( "Could not import openai python package. " "Please it install it with `pip install openai`." - ) + ) from e return values def run(self, query: str) -> str: """Run query through OpenAI and parse result.""" - image_url = self._dalle_image_url(query) - - if image_url is None or image_url == "": - # We don't want to return the assumption alone if answer is empty - return "No image was generated" - else: - return image_url + response = self.client.create(prompt=query, n=self.n, size=self.size) + image_urls = self.separator.join([item["url"] for item in response["data"]]) + return image_urls if image_urls else "No image was generated"