mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
7cf2d2759d
Added missing docstrings. Formatted docstrings into a consistent form.
266 lines
9.2 KiB
Python
266 lines
9.2 KiB
Python
"""Wrapper around EdenAI's Generation API."""
|
|
import logging
|
|
from typing import Any, Dict, List, Literal, Optional
|
|
|
|
from aiohttp import ClientSession
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.pydantic_v1 import Extra, Field, root_validator
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
|
|
from langchain_community.llms.utils import enforce_stop_tokens
|
|
from langchain_community.utilities.requests import Requests
|
|
|
|
# Module-level logger; build_extra uses it to warn when an unrecognized
# constructor argument is moved into model_kwargs.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EdenAI(LLM):
    """EdenAI models.

    To use, you should have
    the environment variable ``EDENAI_API_KEY`` set with your API token.
    You can find your token here: https://app.edenai.run/admin/account/settings

    `provider` must be supplied; `feature` and `subfeature` default to text
    generation. Any other model parameters can also be passed in with the
    format params={model_param: value, ...}

    For the API reference, check the EdenAI documentation: http://docs.edenai.co.
    """

    base_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[str] = None

    feature: Literal["text", "image"] = "text"
    """Which generative feature to use, use text by default"""

    subfeature: Literal["generation"] = "generation"
    """Subfeature of above feature, use generation by default"""

    provider: str
    """Generative provider to use (eg: openai,stabilityai,cohere,google etc.)"""

    model: Optional[str] = None
    """
    model name for above provider (eg: 'gpt-3.5-turbo-instruct' for openai)
    available models are shown on https://docs.edenai.co/ under 'available providers'
    """

    # Optional parameters to add depending of chosen feature
    # see api reference for more infos
    temperature: Optional[float] = Field(default=None, ge=0, le=1)  # for text
    max_tokens: Optional[int] = Field(default=None, ge=0)  # for text
    resolution: Optional[Literal["256x256", "512x512", "1024x1024"]] = None  # for image

    params: Dict[str, Any] = Field(default_factory=dict)
    """
    DEPRECATED: use temperature, max_tokens, resolution directly
    optional parameters to pass to api
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """extra parameters"""

    stop_sequences: Optional[List[str]] = None
    """Stop sequences to use."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        values["edenai_api_key"] = get_from_dict_or_env(
            values, "edenai_api_key", "EDENAI_API_KEY"
        )
        return values

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any key in ``values`` that is not a declared field alias is moved into
        ``model_kwargs`` (with a warning).

        Raises:
            ValueError: if such a key is also present in ``model_kwargs``.
        """
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        # Iterate over a snapshot of the keys because entries are popped below.
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of model."""
        return "edenai"

    def _format_output(self, output: dict) -> str:
        """Pick the generated text, or the first image, out of an API response."""
        if self.feature == "text":
            return output[self.provider]["generated_text"]
        else:
            return output[self.provider]["items"][0]["image"]

    @staticmethod
    def get_user_agent() -> str:
        """Return the User-Agent header value sent with each request."""
        from langchain_community import __version__

        return f"langchain/{__version__}"

    def _resolve_stops(self, stop: Optional[List[str]]) -> Optional[List[str]]:
        """Choose between the per-call ``stop`` and the ``stop_sequences`` field.

        Raises:
            ValueError: if both are provided.
        """
        if self.stop_sequences is not None and stop is not None:
            raise ValueError(
                "stop sequences found in both the input and default params."
            )
        if self.stop_sequences is not None:
            return self.stop_sequences
        return stop

    def _build_headers(self) -> Dict[str, str]:
        """Build the HTTP headers shared by the sync and async requests."""
        return {
            "Authorization": f"Bearer {self.edenai_api_key}",
            "User-Agent": self.get_user_agent(),
        }

    def _build_payload(self, prompt: str, extra: Dict[str, Any]) -> Dict[str, Any]:
        """Build the JSON request body for ``prompt``.

        ``extra`` holds the per-call keyword arguments; they take precedence
        over ``self.params``, which in turn override the dedicated fields.
        """
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "text": prompt,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "resolution": self.resolution,
            **self.params,
            **extra,
            "num_images": 1,  # always limit to 1 (ignored for text)
        }

        # filter `None` values to not pass them to the http payload as null
        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}
        return payload

    @staticmethod
    def _raise_for_status(status: int, body: str) -> None:
        """Raise for any HTTP status other than 200; ``body`` is the response text."""
        if status >= 500:
            raise Exception(f"EdenAI Server: Error {status}")
        if status >= 400:
            raise ValueError(f"EdenAI received an invalid payload: {body}")
        if status != 200:
            raise Exception(
                f"EdenAI returned an unexpected response with status "
                f"{status}: {body}"
            )

    def _extract_output(self, data: dict, stops: Optional[List[str]]) -> str:
        """Turn a decoded API response into the final output string.

        Raises:
            Exception: if the provider entry reports ``status == "fail"``.
        """
        provider_response = data[self.provider]
        if provider_response.get("status") == "fail":
            # "error" may be present but null; treat that like a missing entry.
            err_msg = (provider_response.get("error") or {}).get("message")
            raise Exception(err_msg)

        output = self._format_output(data)
        # An empty list carries no stop sequence, so only enforce non-empty ones.
        if stops:
            output = enforce_stop_tokens(output, stops)
        return output

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to EdenAI's generation endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Extra request parameters merged into the payload.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: if stop sequences are given both here and on the
                instance, or if the API answers with a 4xx status.
            Exception: for 5xx or other non-200 statuses, or when the provider
                reports a failure.
        """
        stops = self._resolve_stops(stop)

        url = f"{self.base_url}/{self.feature}/{self.subfeature}"
        payload = self._build_payload(prompt, kwargs)

        request = Requests(headers=self._build_headers())
        response = request.post(url=url, data=payload)

        if response.status_code != 200:
            self._raise_for_status(response.status_code, response.text)

        return self._extract_output(response.json(), stops)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call EdenAi model to get predictions based on the prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of stop words (optional).
            run_manager: A callback manager for async interaction with LLMs.
                Accepted for interface compatibility; not used here.
            **kwargs: Extra request parameters merged into the payload.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: if stop sequences are given both here and on the
                instance, or if the API answers with a 4xx status.
            Exception: for 5xx or other non-200 statuses, or when the provider
                reports a failure.
        """
        stops = self._resolve_stops(stop)

        url = f"{self.base_url}/{self.feature}/{self.subfeature}"
        headers = self._build_headers()
        payload = self._build_payload(prompt, kwargs)

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                if response.status != 200:
                    # On an aiohttp response, text() is a coroutine and must be
                    # awaited to get the body. The 5xx message does not include
                    # the body, so it is not read in that case.
                    body = await response.text() if response.status < 500 else ""
                    self._raise_for_status(response.status, body)

                response_json = await response.json()

        return self._extract_output(response_json, stops)
|