mirror of https://github.com/hwchase17/langchain
Add aviary support (#5661)
Aviary is an open source toolkit for evaluating and deploying open source LLMs. You can find out more about it at [https://github.com/ray-project/aviary](https://github.com/ray-project/aviary). You can try it out at [http://aviary.anyscale.com](http://aviary.anyscale.com). This code adds support for Aviary in LangChain. To minimize dependencies, it connects directly to the HTTP endpoint. The current implementation is not accelerated and uses the default implementation of `predict` and `generate`. It includes a test and a simple example. @hwchase17 and @agola11 could you have a look at this? --------- Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>pull/5761/head
parent
a47c8618ec
commit
5124c1e0d9
@ -0,0 +1,103 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "9597802c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Aviary\n",
|
||||||
|
"\n",
|
||||||
|
"[Aviary](https://www.anyscale.com/) is an open source toolkit for evaluating and deploying production open source LLMs. \n",
|
||||||
|
"\n",
|
||||||
|
"This example goes over how to use LangChain to interact with `Aviary`. You can try Aviary out [here](https://aviary.anyscale.com).\n",
|
||||||
|
"\n",
|
||||||
|
"You can find out more about Aviary at https://github.com/ray-project/aviary. \n",
|
||||||
|
"\n",
|
||||||
|
"One Aviary instance can serve multiple models. You can get a list of the available models by using the cli:\n",
|
||||||
|
"\n",
|
||||||
|
"`% aviary models`\n",
|
||||||
|
"\n",
|
||||||
|
"Or you can connect directly to the endpoint and get a list of available models by using the `/models` endpoint.\n",
|
||||||
|
"\n",
|
||||||
|
"The constructor requires a url for an Aviary backend, and optionally a token to validate the connection. \n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "6fb585dd",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from langchain.llms import Aviary\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "3fec5a59",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm = Aviary(model='amazon/LightGPT', aviary_url=os.environ['AVIARY_URL'], aviary_token=os.environ['AVIARY_TOKEN'])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "4efd54dd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Love is an emotion that involves feelings of attraction, affection and empathy for another person. It can also refer to a deep bond between two people or groups of people. Love can be expressed in many different ways, such as through words, actions, gestures, music, art, literature, and other forms of communication.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"result = llm.predict('What is the meaning of love?')\n",
|
||||||
|
"print(result) "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "27e526b6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.15"
|
||||||
|
},
|
||||||
|
"vscode": {
|
||||||
|
"interpreter": {
|
||||||
|
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,136 @@
|
|||||||
|
"""Wrapper around Aviary"""
|
||||||
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from pydantic import Extra, Field, root_validator
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
TIMEOUT = 60
|
||||||
|
|
||||||
|
|
||||||
|
class Aviary(LLM):
    """Allow you to use an Aviary.

    Aviary is a backend for hosted models. You can
    find out more about aviary at
    http://github.com/ray-project/aviary

    Has no dependencies, since it connects to backend
    directly.

    To get a list of the models supported on an
    aviary, follow the instructions on the web site to
    install the aviary CLI and then use:
    `aviary models`

    You must at least specify the environment
    variable or parameter AVIARY_URL.

    You may optionally specify the environment variable
    or parameter AVIARY_TOKEN.

    Example:
        .. code-block:: python

            from langchain.llms import Aviary
            light = Aviary(aviary_url='AVIARY_URL',
                model='amazon/LightGPT')

            result = light.predict('How do you make fried rice?')
    """

    # Fully qualified model name served by the backend, e.g. "amazon/LightGPT".
    model: str
    # Base URL of the Aviary backend; normalized to always end with "/".
    aviary_url: str
    # Optional bearer token; excluded from pydantic serialization so the
    # secret is not written out with the model config.
    aviary_token: str = Field("", exclude=True)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Normalizes ``aviary_url`` to end with a slash, resolves the optional
        token, and confirms the requested model is actually served by the
        backend by querying its ``/models`` endpoint.
        """
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
        if not aviary_url.endswith("/"):
            aviary_url += "/"
        values["aviary_url"] = aviary_url
        aviary_token = get_from_dict_or_env(
            values, "aviary_token", "AVIARY_TOKEN", default=""
        )
        values["aviary_token"] = aviary_token

        aviary_endpoint = aviary_url + "models"
        headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
        try:
            # Bound the request so construction cannot hang indefinitely on an
            # unresponsive backend (same timeout as _call below).
            response = requests.get(
                aviary_endpoint, headers=headers, timeout=TIMEOUT
            )
            result = response.json()
            # Confirm model is available
            if values["model"] not in result:
                raise ValueError(
                    f"{aviary_url} does not support model {values['model']}."
                )

        except requests.exceptions.RequestException as e:
            # Surface connection problems as a configuration error while
            # keeping the original exception as the cause.
            raise ValueError(e) from e

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        # NOTE(review): this exposes aviary_token even though the field is
        # declared with exclude=True — confirm these params are never
        # logged/serialized before relying on the token staying secret.
        return {
            "aviary_url": self.aviary_url,
            "aviary_token": self.aviary_token,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "aviary"

    @property
    def headers(self) -> Dict[str, str]:
        """HTTP headers for backend requests, with auth when a token is set."""
        if self.aviary_token:
            return {"Authorization": f"Bearer {self.aviary_token}"}
        else:
            return {}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to Aviary

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = aviary("Tell me a joke.")
        """
        # The backend routes per model on the path, with "/" in the model
        # name encoded as "--" (e.g. amazon/LightGPT -> amazon--LightGPT).
        url = self.aviary_url + "query/" + self.model.replace("/", "--")
        response = requests.post(
            url,
            headers=self.headers,
            json={"prompt": prompt},
            timeout=TIMEOUT,
        )
        try:
            # Response payload is keyed by the original model name.
            text = response.json()[self.model]["generated_text"]
        except requests.JSONDecodeError as e:
            raise ValueError(
                f"Error decoding JSON from {url}. Text response: {response.text}",
            ) from e
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text
|
@ -0,0 +1,10 @@
|
|||||||
|
"""Test Aviary API wrapper."""
|
||||||
|
|
||||||
|
from langchain.llms.aviary import Aviary
|
||||||
|
|
||||||
|
|
||||||
|
def test_aviary_call() -> None:
    """Test valid call to Aviary.

    Requires AVIARY_URL (and optionally AVIARY_TOKEN) in the environment;
    the Aviary constructor validates the connection against the backend.
    """
    llm = Aviary(model="test/model")
    output = llm("Say bar:")
    # The default predict path should yield plain generated text.
    assert isinstance(output, str)
|
Loading…
Reference in New Issue