You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/llms/aviary.py

137 lines
4.0 KiB
Python

"""Wrapper around Aviary"""
from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
# Timeout, in seconds, passed to `requests` when querying the Aviary backend.
TIMEOUT = 60
class Aviary(LLM):
    """LLM that queries a model hosted on an Aviary backend.

    Aviary is a backend for hosted models. You can find out more about
    Aviary at http://github.com/ray-project/aviary

    This wrapper has no extra dependencies: it talks to the backend
    directly over HTTP.

    To get a list of the models supported on an Aviary, follow the
    instructions on the web site to install the aviary CLI and then use:
    `aviary models`

    You must at least specify the environment variable or parameter
    AVIARY_URL.

    You may optionally specify the environment variable or parameter
    AVIARY_TOKEN.

    Example:
        .. code-block:: python

            from langchain.llms import Aviary
            light = Aviary(aviary_url='AVIARY_URL',
                           model='amazon/LightGPT')

            result = light.predict('How do you make fried rice?')
    """

    # Model identifier as listed by the backend, e.g. "amazon/LightGPT".
    model: str
    # Base URL of the Aviary backend; the validator normalises it to end in "/".
    aviary_url: str
    # Optional bearer token; excluded from serialized output via exclude=True.
    aviary_token: str = Field("", exclude=True)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the Aviary URL/token and check that the model is served.

        ``aviary_url`` and ``aviary_token`` are taken from ``values`` or, if
        absent, from the ``AVIARY_URL`` / ``AVIARY_TOKEN`` environment
        variables (the token defaults to ``""``). The URL is normalised to end
        with ``/`` and the backend's ``models`` endpoint is queried to confirm
        the requested model is available.

        Raises:
            ValueError: if the backend cannot be reached, answers with an
                HTTP error status or invalid JSON, or does not list the
                requested model.
        """
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
        if not aviary_url.endswith("/"):
            aviary_url += "/"
        values["aviary_url"] = aviary_url
        aviary_token = get_from_dict_or_env(
            values, "aviary_token", "AVIARY_TOKEN", default=""
        )
        values["aviary_token"] = aviary_token

        aviary_endpoint = aviary_url + "models"
        headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
        try:
            # Bounded timeout: without it an unreachable backend would block
            # model construction indefinitely.
            response = requests.get(
                aviary_endpoint, headers=headers, timeout=TIMEOUT
            )
            # Surface auth/server errors directly instead of misreporting
            # them as "model not supported".
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            raise ValueError(e) from e

        # Confirm model is available. If "model" was not supplied at all,
        # skip the check and let pydantic report the missing required field
        # (a raw KeyError here would bypass pydantic's error reporting).
        model = values.get("model")
        if model is not None and model not in result:
            raise ValueError(f"{aviary_url} does not support model {model}.")
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        LangChain's base LLM derives this instance's identity from this
        mapping (it is used, e.g., when building LLM-cache keys). It
        therefore includes the model, so different models never share
        cached generations. It deliberately omits the secret
        ``aviary_token``.
        """
        return {
            "model": self.model,
            "aviary_url": self.aviary_url,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "aviary"

    @property
    def headers(self) -> Dict[str, str]:
        """HTTP headers for backend requests: a bearer token if one is set."""
        if self.aviary_token:
            return {"Authorization": f"Bearer {self.aviary_token}"}
        return {}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to Aviary.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional stop sequences. They are not sent to the backend;
                they are applied to the returned text via
                ``enforce_stop_tokens``.
            run_manager: Callback manager for this run (not used here).

        Returns:
            The string generated by the model.

        Raises:
            ValueError: if the response body is not valid JSON, or does not
                contain ``generated_text`` for ``self.model``.

        Example:
            .. code-block:: python

                response = aviary("Tell me a joke.")
        """
        # The query route takes the model id with "/" replaced by "--".
        url = self.aviary_url + "query/" + self.model.replace("/", "--")
        response = requests.post(
            url,
            headers=self.headers,
            json={"prompt": prompt},
            timeout=TIMEOUT,
        )
        try:
            text = response.json()[self.model]["generated_text"]
        except requests.JSONDecodeError as e:
            raise ValueError(
                f"Error decoding JSON from {url}. Text response: {response.text}",
            ) from e
        except (KeyError, TypeError) as e:
            # Valid JSON but not the expected {model: {"generated_text": ...}}
            # shape, e.g. an error payload from the backend.
            raise ValueError(
                f"Unexpected response from {url}. Text response: {response.text}",
            ) from e
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text