mirror of https://github.com/hwchase17/langchain
gradient.ai LLM intregration (#10800)
- **Description:** This PR implements a new LLM API to https://gradient.ai - **Issue:** Feature request for LLM #10745 - **Dependencies**: No additional dependencies are introduced. - **Tag maintainer:** I am opening this PR for visibility, once ready for review I'll tag. - ```make format && make lint && make test``` is running. - added a `integration` and `mock unit` test. Co-authored-by: michaelfeil <me@michaelfeil.eu> Co-authored-by: Bagatur <baskaryan@gmail.com>pull/10891/head
parent
5097007407
commit
55570e54e1
@ -0,0 +1,216 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gradient\n",
|
||||
"\n",
|
||||
"`Gradient` allows to fine tune and get completions on LLMs with a simple web API.\n",
|
||||
"\n",
|
||||
"This notebook goes over how to use Langchain with [Gradient](https://gradient.ai/).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Imports"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import requests\n",
|
||||
"from langchain.llms import GradientLLM\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set the Environment API Key\n",
|
||||
"Make sure to get your API key from Gradient AI. You are given $10 in free credits to test and fine-tune different models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\",None):\n",
|
||||
" # Access token under https://auth.gradient.ai/select-workspace\n",
|
||||
" os.environ[\"GRADIENT_ACCESS_TOKEN\"] = getpass(\"gradient.ai access token:\")\n",
|
||||
"if not os.environ.get(\"GRADIENT_WORKSPACE_ID\",None):\n",
|
||||
" # `ID` listed in `$ gradient workspace list`\n",
|
||||
" # also displayed after login at at https://auth.gradient.ai/select-workspace\n",
|
||||
" os.environ[\"GRADIENT_WORKSPACE_ID\"] = getpass(\"gradient.ai workspace id:\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Optional: Validate your Enviroment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Credentials valid.\n",
|
||||
"Possible values for `model_id` are:\n",
|
||||
" {'models': [{'id': '99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model', 'name': 'bloom-560m', 'slug': 'bloom-560m', 'type': 'baseModel'}, {'id': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'name': 'llama2-7b-chat', 'slug': 'llama2-7b-chat', 'type': 'baseModel'}, {'id': 'cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model', 'name': 'nous-hermes2', 'slug': 'nous-hermes2', 'type': 'baseModel'}, {'baseModelId': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'id': 'bb7b9865-0ce3-41a8-8e2b-5cbcbe1262eb_model_adapter', 'name': 'optical-transmitting-sensor', 'type': 'modelAdapter'}]}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"resp = requests.get(f'https://api.gradient.ai/api/models', headers={\n",
|
||||
" \"authorization\": f\"Bearer {os.environ['GRADIENT_ACCESS_TOKEN']}\",\n",
|
||||
" \"x-gradient-workspace-id\": f\"{os.environ['GRADIENT_WORKSPACE_ID']}\",\n",
|
||||
" },\n",
|
||||
" )\n",
|
||||
"if resp.status_code == 200:\n",
|
||||
" models = resp.json()\n",
|
||||
" print(\"Credentials valid.\\nPossible values for `model_id` are:\\n\", models)\n",
|
||||
"else:\n",
|
||||
" print(\"Error when listing models. Are your credentials valid?\", resp.text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create the Gradient instance\n",
|
||||
"You can specify different parameters such as the model name, max tokens generated, temperature, etc."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = GradientLLM(\n",
|
||||
" # `ID` listed in `$ gradient model list`\n",
|
||||
" model_id=\"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\",\n",
|
||||
" # # optional: set new credentials, they default to environment variables\n",
|
||||
" # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n",
|
||||
" # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create a Prompt Template\n",
|
||||
"We will create a prompt template for Question and Answer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Initiate the LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run the LLMChain\n",
|
||||
"Provide a question and run the LLMChain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' The first team to win the Super Bowl was the New England Patriots. The Patriots won the'"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"question = \"What NFL team won the Super Bowl in 1994?\"\n",
|
||||
"\n",
|
||||
"llm_chain.run(\n",
|
||||
" question=question\n",
|
||||
")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
@ -0,0 +1,236 @@
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class GradientLLM(LLM):
|
||||
"""Gradient.ai LLM Endpoints.
|
||||
|
||||
GradientLLM is a class to interact with LLMs on gradient.ai
|
||||
|
||||
To use, set the environment variable ``GRADIENT_ACCESS_TOKEN`` with your
|
||||
API token and ``GRADIENT_WORKSPACE_ID`` for your gradient workspace,
|
||||
or alternatively provide them as keywords to the constructor of this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.gradientai_endpoint import GradientAIEndpoint
|
||||
GradientLLM(
|
||||
model_id="cad6644_base_ml_model",
|
||||
model_kwargs={
|
||||
"max_generated_token_count": 200,
|
||||
"temperature": 0.75,
|
||||
"top_p": 0.95,
|
||||
"top_k": 20,
|
||||
"stop": [],
|
||||
},
|
||||
gradient_workspace_id="12345614fc0_workspace",
|
||||
gradient_access_token="gradientai-access_token",
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
model_id: str
|
||||
"Underlying gradient.ai model id (base or fine-tuned)."
|
||||
|
||||
gradient_workspace_id: Optional[str] = None
|
||||
"Underlying gradient.ai workspace_id."
|
||||
|
||||
gradient_access_token: Optional[str] = None
|
||||
"""gradient.ai API Token, which can be generated by going to
|
||||
https://auth.gradient.ai/select-workspace
|
||||
and selecting "Access tokens" under the profile drop-down.
|
||||
"""
|
||||
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
|
||||
gradient_api_url: str = "https://api.gradient.ai/api"
|
||||
"""Endpoint URL to use."""
|
||||
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
"""ClientSession, in case we want to reuse connection for better performance."""
|
||||
|
||||
# LLM call kwargs
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
|
||||
values["gradient_access_token"] = get_from_dict_or_env(
|
||||
values, "gradient_access_token", "GRADIENT_ACCESS_TOKEN"
|
||||
)
|
||||
values["gradient_workspace_id"] = get_from_dict_or_env(
|
||||
values, "gradient_workspace_id", "GRADIENT_WORKSPACE_ID"
|
||||
)
|
||||
|
||||
if (
|
||||
values["gradient_access_token"] is None
|
||||
or len(values["gradient_access_token"]) < 10
|
||||
):
|
||||
raise ValueError("env variable `GRADIENT_ACCESS_TOKEN` must be set")
|
||||
|
||||
if (
|
||||
values["gradient_workspace_id"] is None
|
||||
or len(values["gradient_access_token"]) < 3
|
||||
):
|
||||
raise ValueError("env variable `GRADIENT_WORKSPACE_ID` must be set")
|
||||
|
||||
if values["model_kwargs"]:
|
||||
kw = values["model_kwargs"]
|
||||
if not 0 <= kw.get("temperature", 0.5) <= 1:
|
||||
raise ValueError("`temperature` must be in the range [0.0, 1.0]")
|
||||
|
||||
if not 0 <= kw.get("top_p", 0.5) <= 1:
|
||||
raise ValueError("`top_p` must be in the range [0.0, 1.0]")
|
||||
|
||||
if 0 >= kw.get("top_k", 0.5):
|
||||
raise ValueError("`top_k` must be positive")
|
||||
|
||||
if 0 >= kw.get("max_generated_token_count", 1):
|
||||
raise ValueError("`max_generated_token_count` must be positive")
|
||||
|
||||
values["gradient_api_url"] = get_from_dict_or_env(
|
||||
values, "gradient_api_url", "GRADIENT_API_URL"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"gradient_api_url": self.gradient_api_url},
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "gradient"
|
||||
|
||||
def _kwargs_post_request(
|
||||
self, prompt: str, kwargs: Mapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
"""Build the kwargs for the Post request, used by sync
|
||||
|
||||
Args:
|
||||
prompt (str): prompt used in query
|
||||
kwargs (dict): model kwargs in payload
|
||||
|
||||
Returns:
|
||||
Dict[str, Union[str,dict]]: _description_
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_params = {**_model_kwargs, **kwargs}
|
||||
|
||||
return dict(
|
||||
url=f"{self.gradient_api_url}/models/{self.model_id}/complete",
|
||||
headers={
|
||||
"authorization": f"Bearer {self.gradient_access_token}",
|
||||
"x-gradient-workspace-id": f"{self.gradient_workspace_id}",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
},
|
||||
json=dict(
|
||||
query=prompt,
|
||||
maxGeneratedTokenCount=_params.get("max_generated_token_count", None),
|
||||
temperature=_params.get("temperature", None),
|
||||
topK=_params.get("top_k", None),
|
||||
topP=_params.get("top_p", None),
|
||||
),
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call to Gradients API `model/{id}/complete`.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
try:
|
||||
response = requests.post(**self._kwargs_post_request(prompt, kwargs))
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Gradient returned an unexpected response with status "
|
||||
f"{response.status_code}: {response.text}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise Exception(f"RequestException while calling Gradient Endpoint: {e}")
|
||||
|
||||
text = response.json()["generatedOutput"]
|
||||
|
||||
if stop is not None:
|
||||
# Apply stop tokens when making calls to Gradient
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Union[List[str], None] = None,
|
||||
run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Async Call to Gradients API `model/{id}/complete`.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
**self._kwargs_post_request(prompt=prompt, kwargs=kwargs)
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Gradient returned an unexpected response with status "
|
||||
f"{response.status}: {response.text}"
|
||||
)
|
||||
text = (await response.json())["generatedOutput"]
|
||||
else:
|
||||
async with self.aiosession.post(
|
||||
**self._kwargs_post_request(prompt=prompt, kwargs=kwargs)
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(
|
||||
f"Gradient returned an unexpected response with status "
|
||||
f"{response.status}: {response.text}"
|
||||
)
|
||||
text = (await response.json())["generatedOutput"]
|
||||
|
||||
if stop is not None:
|
||||
# Apply stop tokens when making calls to Gradient
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
@ -0,0 +1,36 @@
|
||||
"""Test GradientAI API wrapper.
|
||||
|
||||
In order to run this test, you need to have an GradientAI api key.
|
||||
You can get it by registering for free at https://gradient.ai/.
|
||||
|
||||
You'll then need to set:
|
||||
- `GRADIENT_ACCESS_TOKEN` environment variable to your api key.
|
||||
- `GRADIENT_WORKSPACE_ID` environment variable to your workspace id.
|
||||
- `GRADIENT_MODEL_ID` environment variable to your workspace id.
|
||||
"""
|
||||
import os
|
||||
|
||||
from langchain.llms import GradientLLM
|
||||
|
||||
|
||||
def test_gradient_acall() -> None:
|
||||
"""Test simple call to gradient.ai."""
|
||||
model_id = os.environ["GRADIENT_MODEL_ID"]
|
||||
llm = GradientLLM(model_id=model_id)
|
||||
output = llm("Say hello:", temperature=0.2, max_tokens=250)
|
||||
|
||||
assert llm._llm_type == "gradient"
|
||||
|
||||
assert isinstance(output, str)
|
||||
assert len(output)
|
||||
|
||||
|
||||
async def test_gradientai_acall() -> None:
|
||||
"""Test async call to gradient.ai."""
|
||||
model_id = os.environ["GRADIENT_MODEL_ID"]
|
||||
llm = GradientLLM(model_id=model_id)
|
||||
output = await llm.agenerate(["Say hello:"], temperature=0.2, max_tokens=250)
|
||||
assert llm._llm_type == "gradient"
|
||||
|
||||
assert isinstance(output, str)
|
||||
assert len(output)
|
@ -0,0 +1,56 @@
|
||||
from typing import Dict
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from langchain.llms import GradientLLM
|
||||
|
||||
_MODEL_ID = "my_model_valid_id"
|
||||
_GRADIENT_SECRET = "secret_valid_token_123456"
|
||||
_GRADIENT_WORKSPACE_ID = "valid_workspace_12345"
|
||||
_GRADIENT_BASE_URL = "https://api.gradient.ai/api"
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, json_data: Dict, status_code: int):
|
||||
self.json_data = json_data
|
||||
self.status_code = status_code
|
||||
|
||||
def json(self) -> Dict:
|
||||
return self.json_data
|
||||
|
||||
|
||||
def mocked_requests_post(
|
||||
url: str,
|
||||
headers: dict,
|
||||
json: dict,
|
||||
) -> MockResponse:
|
||||
assert url.startswith(_GRADIENT_BASE_URL)
|
||||
assert headers
|
||||
assert json
|
||||
|
||||
return MockResponse(
|
||||
json_data={"generatedOutput": "bar"},
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
def test_gradient_llm_sync(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch("requests.post", side_effect=mocked_requests_post)
|
||||
|
||||
llm = GradientLLM(
|
||||
gradient_api_url=_GRADIENT_BASE_URL,
|
||||
gradient_access_token=_GRADIENT_SECRET,
|
||||
gradient_workspace_id=_GRADIENT_WORKSPACE_ID,
|
||||
model_id=_MODEL_ID,
|
||||
)
|
||||
assert llm.gradient_access_token == _GRADIENT_SECRET
|
||||
assert llm.gradient_api_url == _GRADIENT_BASE_URL
|
||||
assert llm.gradient_workspace_id == _GRADIENT_WORKSPACE_ID
|
||||
assert llm.model_id == _MODEL_ID
|
||||
|
||||
response = llm("Say foo:")
|
||||
want = "bar"
|
||||
|
||||
assert response == want
|
Loading…
Reference in New Issue