langchain/libs/community/langchain_community/utilities/dalle_image_generator.py
Erick Friis c2a3021bb0
multiple: pydantic 2 compatibility, v0.3 (#26443)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Dan O'Donovan <dan.odonovan@gmail.com>
Co-authored-by: Tom Daniel Grande <tomdgrande@gmail.com>
Co-authored-by: Grande <Tom.Daniel.Grande@statsbygg.no>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Co-authored-by: Tomaz Bratanic <bratanic.tomaz@gmail.com>
Co-authored-by: ZhangShenao <15201440436@163.com>
Co-authored-by: Friso H. Kingma <fhkingma@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Morgante Pell <morgantep@google.com>
2024-09-13 14:38:45 -07:00

158 lines
5.8 KiB
Python

"""Utility that calls OpenAI's Dall-E Image Generator."""
import logging
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from langchain_core.utils import (
from_env,
get_pydantic_field_names,
secret_from_env,
)
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self
from langchain_community.utils.openai import is_openai_v1
logger = logging.getLogger(__name__)
class DallEAPIWrapper(BaseModel):
"""Wrapper for OpenAI's DALL-E Image Generator.
https://platform.openai.com/docs/guides/images/generations?context=node
Usage instructions:
1. `pip install openai`
2. save your OPENAI_API_KEY in an environment variable
"""
client: Any = None #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="dall-e-2", alias="model")
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
openai_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
"OPENAI_API_KEY",
default=None,
),
)
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(
alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None)
)
"""Base URL path for API requests, leave blank if not using a proxy or service
emulator."""
openai_organization: Optional[str] = Field(
alias="organization",
default_factory=from_env(
["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None
),
)
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
# to support explicit proxy for OpenAI
openai_proxy: str = Field(default_factory=from_env("OPENAI_PROXY", default=""))
request_timeout: Union[float, Tuple[float, float], Any, None] = Field(
default=None, alias="timeout"
)
n: int = 1
"""Number of images to generate"""
size: str = "1024x1024"
"""Size of image to generate"""
separator: str = "\n"
"""Separator to use when multiple URLs are returned."""
quality: Optional[str] = "standard"
"""Quality of the image that will be generated"""
max_retries: int = 2
"""Maximum number of retries to make when generating."""
default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the
# [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: Union[Any, None] = None
"""Optional httpx.Client."""
model_config = ConfigDict(extra="forbid", protected_namespaces=())
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key and python package exists in environment."""
try:
import openai
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if is_openai_v1():
client_params = {
"api_key": self.openai_api_key,
"organization": self.openai_organization,
"base_url": self.openai_api_base,
"timeout": self.request_timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
"http_client": self.http_client,
}
if not self.client:
self.client = openai.OpenAI(**client_params).images
if not self.async_client:
self.async_client = openai.AsyncOpenAI(**client_params).images
elif not self.client:
self.client = openai.Image
else:
pass
return self
def run(self, query: str) -> str:
"""Run query through OpenAI and parse result."""
if is_openai_v1():
response = self.client.generate(
prompt=query,
n=self.n,
size=self.size,
model=self.model_name,
quality=self.quality,
)
image_urls = self.separator.join([item.url for item in response.data])
else:
response = self.client.create(
prompt=query, n=self.n, size=self.size, model=self.model_name
)
image_urls = self.separator.join([item["url"] for item in response["data"]])
return image_urls if image_urls else "No image was generated"