diff --git a/libs/langchain/langchain/utilities/dalle_image_generator.py b/libs/langchain/langchain/utilities/dalle_image_generator.py index dc25c44b73..f5d37b514a 100644 --- a/libs/langchain/langchain/utilities/dalle_image_generator.py +++ b/libs/langchain/langchain/utilities/dalle_image_generator.py @@ -8,29 +8,25 @@ from langchain.utils import get_from_dict_or_env class DallEAPIWrapper(BaseModel): """Wrapper for OpenAI's DALL-E Image Generator. - Docs for using: - 1. pip install openai + Usage instructions: + 1. `pip install openai` 2. save your OPENAI_API_KEY in an environment variable - """ client: Any #: :meta private: openai_api_key: Optional[str] = None - """number of images to generate""" n: int = 1 - """size of image to generate""" + """Number of images to generate""" size: str = "1024x1024" + """Size of image to generate""" + separator: str = "\n" + """Separator to use when multiple URLs are returned.""" class Config: """Configuration for this pydantic object.""" extra = Extra.forbid - def _dalle_image_url(self, prompt: str) -> str: - params = {"prompt": prompt, "n": self.n, "size": self.size} - response = self.client.create(**params) - return response["data"][0]["url"] - @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" @@ -42,19 +38,15 @@ class DallEAPIWrapper(BaseModel): openai.api_key = openai_api_key values["client"] = openai.Image - except ImportError: - raise ValueError( + except ImportError as e: + raise ImportError( "Could not import openai python package. " "Please it install it with `pip install openai`." - ) + ) from e return values def run(self, query: str) -> str: """Run query through OpenAI and parse result.""" - image_url = self._dalle_image_url(query) - - if image_url is None or image_url == "": - # We don't want to return the assumption alone if answer is empty - return "No image was generated" - else: - return image_url + response = self.client.create(prompt=query, n=self.n, size=self.size) + image_urls = self.separator.join([item["url"] for item in response["data"]]) + return image_urls if image_urls else "No image was generated"