From 5124c1e0d9cba6906dbfada53cfe55489f723c5a Mon Sep 17 00:00:00 2001 From: M Waleed Kadous Date: Mon, 5 Jun 2023 16:28:42 -0700 Subject: [PATCH] Add aviary support (#5661) Aviary is an open source toolkit for evaluating and deploying open source LLMs. You can find out more about it on [http://github.com/ray-project/aviary](http://github.com/ray-project/aviary). You can try it out at [aviary.anyscale.com](http://aviary.anyscale.com). This code adds support for Aviary in LangChain. To minimize dependencies, it connects directly to the HTTP endpoint. The current implementation is not accelerated and uses the default implementation of `predict` and `generate`. It includes a test and a simple example. @hwchase17 and @agola11 could you have a look at this? --------- Co-authored-by: Dev 2049 --- .../models/llms/integrations/aviary.ipynb | 103 +++++++++++++ langchain/llms/__init__.py | 3 + langchain/llms/aviary.py | 136 ++++++++++++++++++ tests/integration_tests/llms/test_aviary.py | 10 ++ 4 files changed, 252 insertions(+) create mode 100644 docs/modules/models/llms/integrations/aviary.ipynb create mode 100644 langchain/llms/aviary.py create mode 100644 tests/integration_tests/llms/test_aviary.py diff --git a/docs/modules/models/llms/integrations/aviary.ipynb b/docs/modules/models/llms/integrations/aviary.ipynb new file mode 100644 index 00000000..397f23a7 --- /dev/null +++ b/docs/modules/models/llms/integrations/aviary.ipynb @@ -0,0 +1,103 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9597802c", + "metadata": {}, + "source": [ + "# Aviary\n", + "\n", + "[Aviary](https://www.anyscale.com/) is an open source toolkit for evaluating and deploying production open source LLMs. \n", + "\n", + "This example goes over how to use LangChain to interact with `Aviary`. You can try Aviary out [here](https://aviary.anyscale.com).\n", + "\n", + "You can find out more about Aviary at https://github.com/ray-project/aviary. \n", + "\n", + "One Aviary instance can serve multiple models. 
You can get a list of the available models by using the cli:\n", + "\n", + "`% aviary models`\n", + "\n", + "Or you can connect directly to the endpoint and get a list of available models by using the `/models` endpoint.\n", + "\n", + "The constructor requires a url for an Aviary backend, and optionally a token to validate the connection. \n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6fb585dd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "from langchain.llms import Aviary\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3fec5a59", + "metadata": {}, + "outputs": [], + "source": [ + "llm = Aviary(model='amazon/LightGPT', aviary_url=os.environ['AVIARY_URL'], aviary_token=os.environ['AVIARY_TOKEN'])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4efd54dd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Love is an emotion that involves feelings of attraction, affection and empathy for another person. It can also refer to a deep bond between two people or groups of people. 
Love can be expressed in many different ways, such as through words, actions, gestures, music, art, literature, and other forms of communication.\n" + ] + } + ], + "source": [ + "result = llm.predict('What is the meaning of love?')\n", + "print(result) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27e526b6", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + }, + "vscode": { + "interpreter": { + "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py index 9dc97c56..e551a23f 100644 --- a/langchain/llms/__init__.py +++ b/langchain/llms/__init__.py @@ -5,6 +5,7 @@ from langchain.llms.ai21 import AI21 from langchain.llms.aleph_alpha import AlephAlpha from langchain.llms.anthropic import Anthropic from langchain.llms.anyscale import Anyscale +from langchain.llms.aviary import Aviary from langchain.llms.bananadev import Banana from langchain.llms.base import BaseLLM from langchain.llms.beam import Beam @@ -47,6 +48,7 @@ __all__ = [ "Anthropic", "AlephAlpha", "Anyscale", + "Aviary", "Banana", "Beam", "Bedrock", @@ -94,6 +96,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = { "aleph_alpha": AlephAlpha, "anthropic": Anthropic, "anyscale": Anyscale, + "aviary": Aviary, "bananadev": Banana, "beam": Beam, "cerebriumai": CerebriumAI, diff --git a/langchain/llms/aviary.py b/langchain/llms/aviary.py new file mode 100644 index 00000000..6f4a48a5 --- /dev/null +++ b/langchain/llms/aviary.py @@ -0,0 +1,136 @@ +"""Wrapper around 
Aviary""" +from typing import Any, Dict, List, Mapping, Optional + +import requests +from pydantic import Extra, Field, root_validator + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM +from langchain.llms.utils import enforce_stop_tokens +from langchain.utils import get_from_dict_or_env + +TIMEOUT = 60 + + +class Aviary(LLM): + """Allow you to use an Aviary. + + Aviary is a backend for hosted models. You can + find out more about aviary at + http://github.com/ray-project/aviary + + Has no dependencies, since it connects to backend + directly. + + To get a list of the models supported on an + aviary, follow the instructions on the web site to + install the aviary CLI and then use: + `aviary models` + + You must at least specify the environment + variable or parameter AVIARY_URL. + + You may optionally specify the environment variable + or parameter AVIARY_TOKEN. + + Example: + .. code-block:: python + + from langchain.llms import Aviary + light = Aviary(aviary_url='AVIARY_URL', + model='amazon/LightGPT') + + result = light.predict('How do you make fried rice?') + """ + + model: str + aviary_url: str + aviary_token: str = Field("", exclude=True) + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL") + if not aviary_url.endswith("/"): + aviary_url += "/" + values["aviary_url"] = aviary_url + aviary_token = get_from_dict_or_env( + values, "aviary_token", "AVIARY_TOKEN", default="" + ) + values["aviary_token"] = aviary_token + + aviary_endpoint = aviary_url + "models" + headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {} + try: + response = requests.get(aviary_endpoint, headers=headers) + result = response.json() + # Confirm 
model is available + if values["model"] not in result: + raise ValueError( + f"{aviary_url} does not support model {values['model']}." + ) + + except requests.exceptions.RequestException as e: + raise ValueError(e) + + return values + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return { + "aviary_url": self.aviary_url, + "aviary_token": self.aviary_token, + } + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "aviary" + + @property + def headers(self) -> Dict[str, str]: + if self.aviary_token: + return {"Authorization": f"Bearer {self.aviary_token}"} + else: + return {} + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: + """Call out to Aviary + Args: + prompt: The prompt to pass into the model. + + Returns: + The string generated by the model. + + Example: + .. code-block:: python + + response = aviary("Tell me a joke.") + """ + url = self.aviary_url + "query/" + self.model.replace("/", "--") + response = requests.post( + url, + headers=self.headers, + json={"prompt": prompt}, + timeout=TIMEOUT, + ) + try: + text = response.json()[self.model]["generated_text"] + except requests.JSONDecodeError as e: + raise ValueError( + f"Error decoding JSON from {url}. Text response: {response.text}", + ) from e + if stop: + text = enforce_stop_tokens(text, stop) + return text diff --git a/tests/integration_tests/llms/test_aviary.py b/tests/integration_tests/llms/test_aviary.py new file mode 100644 index 00000000..d2d67fb3 --- /dev/null +++ b/tests/integration_tests/llms/test_aviary.py @@ -0,0 +1,10 @@ +"""Test Anyscale API wrapper.""" + +from langchain.llms.aviary import Aviary + + +def test_aviary_call() -> None: + """Test valid call to Anyscale.""" + llm = Aviary(model="test/model") + output = llm("Say bar:") + assert isinstance(output, str)