`_dalle_image_url` returns list of urls if n>1 (#11800)

- **Description:** Updated the `_dalle_image_url` method to return a
list of URLs if self.n>1,
  - **Issue:** #10691,
  - **Dependencies:** unsure,
  - **Tag maintainer:** @eyurtsev,
  - **Twitter handle:** @silvhua
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/12525/head^2
silvhua 8 months ago committed by GitHub
parent 1815ea2fdb
commit 9dead1034c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,29 +8,25 @@ from langchain.utils import get_from_dict_or_env
class DallEAPIWrapper(BaseModel):
"""Wrapper for OpenAI's DALL-E Image Generator.
Docs for using:
1. pip install openai
Usage instructions:
1. `pip install openai`
2. save your OPENAI_API_KEY in an environment variable
"""
client: Any #: :meta private:
openai_api_key: Optional[str] = None
"""number of images to generate"""
n: int = 1
"""size of image to generate"""
"""Number of images to generate"""
size: str = "1024x1024"
"""Size of image to generate"""
separator: str = "\n"
"""Separator to use when multiple URLs are returned."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def _dalle_image_url(self, prompt: str) -> str:
params = {"prompt": prompt, "n": self.n, "size": self.size}
response = self.client.create(**params)
return response["data"][0]["url"]
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
@ -42,19 +38,15 @@ class DallEAPIWrapper(BaseModel):
openai.api_key = openai_api_key
values["client"] = openai.Image
except ImportError:
raise ValueError(
except ImportError as e:
raise ImportError(
"Could not import openai python package. "
"Please it install it with `pip install openai`."
)
) from e
return values
def run(self, query: str) -> str:
"""Run query through OpenAI and parse result."""
image_url = self._dalle_image_url(query)
if image_url is None or image_url == "":
# We don't want to return the assumption alone if answer is empty
return "No image was generated"
else:
return image_url
response = self.client.create(prompt=query, n=self.n, size=self.size)
image_urls = self.separator.join([item["url"] for item in response["data"]])
return image_urls if image_urls else "No image was generated"

Loading…
Cancel
Save