Add aviary support (#5661)

Aviary is an open source toolkit for evaluating and deploying open
source LLMs. You can find out more about it on
[http://github.com/ray-project/aviary). You can try it out at
[http://aviary.anyscale.com](aviary.anyscale.com).

This code adds support for Aviary in LangChain. To minimize
dependencies, it connects directly to the HTTP endpoint.

The current implementation is not accelerated and uses the default
implementation of `predict` and `generate`.

It includes a test and a simple example. 

@hwchase17 and @agola11 could you have a look at this?

---------

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
searx_updates
M Waleed Kadous 12 months ago committed by GitHub
parent a47c8618ec
commit 5124c1e0d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,103 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9597802c",
"metadata": {},
"source": [
"# Aviary\n",
"\n",
"[Aviary](https://www.anyscale.com/) is an open source tooklit for evaluating and deploying production open source LLMs. \n",
"\n",
"This example goes over how to use LangChain to interact with `Aviary`. You can try Aviary out [https://aviary.anyscale.com](here).\n",
"\n",
"You can find out more about Aviary at https://github.com/ray-project/aviary. \n",
"\n",
"One Aviary instance can serve multiple models. You can get a list of the available models by using the cli:\n",
"\n",
"`% aviary models`\n",
"\n",
"Or you can connect directly to the endpoint and get a list of available models by using the `/models` endpoint.\n",
"\n",
"The constructor requires a url for an Aviary backend, and optionally a token to validate the connection. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6fb585dd",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"from langchain.llms import Aviary\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3fec5a59",
"metadata": {},
"outputs": [],
"source": [
"llm = Aviary(model='amazon/LightGPT', aviary_url=os.environ['AVIARY_URL'], aviary_token=os.environ['AVIARY_TOKEN'])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4efd54dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Love is an emotion that involves feelings of attraction, affection and empathy for another person. It can also refer to a deep bond between two people or groups of people. Love can be expressed in many different ways, such as through words, actions, gestures, music, art, literature, and other forms of communication.\n"
]
}
],
"source": [
"result = llm.predict('What is the meaning of love?')\n",
"print(result) "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27e526b6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -5,6 +5,7 @@ from langchain.llms.ai21 import AI21
from langchain.llms.aleph_alpha import AlephAlpha
from langchain.llms.anthropic import Anthropic
from langchain.llms.anyscale import Anyscale
from langchain.llms.aviary import Aviary
from langchain.llms.bananadev import Banana
from langchain.llms.base import BaseLLM
from langchain.llms.beam import Beam
@ -47,6 +48,7 @@ __all__ = [
"Anthropic",
"AlephAlpha",
"Anyscale",
"Aviary",
"Banana",
"Beam",
"Bedrock",
@ -94,6 +96,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"aleph_alpha": AlephAlpha,
"anthropic": Anthropic,
"anyscale": Anyscale,
"aviary": Aviary,
"bananadev": Banana,
"beam": Beam,
"cerebriumai": CerebriumAI,

@ -0,0 +1,136 @@
"""Wrapper around Aviary"""
from typing import Any, Dict, List, Mapping, Optional
import requests
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
TIMEOUT = 60
class Aviary(LLM):
"""Allow you to use an Aviary.
Aviary is a backend for hosted models. You can
find out more about aviary at
http://github.com/ray-project/aviary
Has no dependencies, since it connects to backend
directly.
To get a list of the models supported on an
aviary, follow the instructions on the web site to
install the aviary CLI and then use:
`aviary models`
You must at least specify the environment
variable or parameter AVIARY_URL.
You may optionally specify the environment variable
or parameter AVIARY_TOKEN.
Example:
.. code-block:: python
from langchain.llms import Aviary
light = Aviary(aviary_url='AVIARY_URL',
model='amazon/LightGPT')
result = light.predict('How do you make fried rice?')
"""
model: str
aviary_url: str
aviary_token: str = Field("", exclude=True)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
if not aviary_url.endswith("/"):
aviary_url += "/"
values["aviary_url"] = aviary_url
aviary_token = get_from_dict_or_env(
values, "aviary_token", "AVIARY_TOKEN", default=""
)
values["aviary_token"] = aviary_token
aviary_endpoint = aviary_url + "models"
headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
try:
response = requests.get(aviary_endpoint, headers=headers)
result = response.json()
# Confirm model is available
if values["model"] not in result:
raise ValueError(
f"{aviary_url} does not support model {values['model']}."
)
except requests.exceptions.RequestException as e:
raise ValueError(e)
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"aviary_url": self.aviary_url,
"aviary_token": self.aviary_token,
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "aviary"
@property
def headers(self) -> Dict[str, str]:
if self.aviary_token:
return {"Authorization": f"Bearer {self.aviary_token}"}
else:
return {}
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
"""Call out to Aviary
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = aviary("Tell me a joke.")
"""
url = self.aviary_url + "query/" + self.model.replace("/", "--")
response = requests.post(
url,
headers=self.headers,
json={"prompt": prompt},
timeout=TIMEOUT,
)
try:
text = response.json()[self.model]["generated_text"]
except requests.JSONDecodeError as e:
raise ValueError(
f"Error decoding JSON from {url}. Text response: {response.text}",
) from e
if stop:
text = enforce_stop_tokens(text, stop)
return text

@ -0,0 +1,10 @@
"""Test Anyscale API wrapper."""
from langchain.llms.aviary import Aviary
def test_aviary_call() -> None:
"""Test valid call to Anyscale."""
llm = Aviary(model="test/model")
output = llm("Say bar:")
assert isinstance(output, str)
Loading…
Cancel
Save