{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradient\n",
"\n",
"`Gradient` allows to fine tune and get completions on LLMs with a simple web API.\n",
"\n",
"This notebook goes over how to use Langchain with [Gradient](https://gradient.ai/).\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import requests\n",
"from langchain.llms import GradientLLM\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set the Environment API Key\n",
"Make sure to get your API key from Gradient AI. You are given $10 in free credits to test and fine-tune different models."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from getpass import getpass\n",
"\n",
"\n",
"if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\",None):\n",
" # Access token under https://auth.gradient.ai/select-workspace\n",
" os.environ[\"GRADIENT_ACCESS_TOKEN\"] = getpass(\"gradient.ai access token:\")\n",
"if not os.environ.get(\"GRADIENT_WORKSPACE_ID\",None):\n",
" # `ID` listed in `$ gradient workspace list`\n",
" # also displayed after login at at https://auth.gradient.ai/select-workspace\n",
" os.environ[\"GRADIENT_WORKSPACE_ID\"] = getpass(\"gradient.ai workspace id:\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional: Validate your Enviroment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Credentials valid.\n",
"Possible values for `model_id` are:\n",
" {'models': [{'id': '99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model', 'name': 'bloom-560m', 'slug': 'bloom-560m', 'type': 'baseModel'}, {'id': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'name': 'llama2-7b-chat', 'slug': 'llama2-7b-chat', 'type': 'baseModel'}, {'id': 'cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model', 'name': 'nous-hermes2', 'slug': 'nous-hermes2', 'type': 'baseModel'}, {'baseModelId': 'f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model', 'id': 'bb7b9865-0ce3-41a8-8e2b-5cbcbe1262eb_model_adapter', 'name': 'optical-transmitting-sensor', 'type': 'modelAdapter'}]}\n"
]
}
],
"source": [
"import requests\n",
"\n",
"resp = requests.get(f'https://api.gradient.ai/api/models', headers={\n",
" \"authorization\": f\"Bearer {os.environ['GRADIENT_ACCESS_TOKEN']}\",\n",
" \"x-gradient-workspace-id\": f\"{os.environ['GRADIENT_WORKSPACE_ID']}\",\n",
" },\n",
" )\n",
"if resp.status_code == 200:\n",
" models = resp.json()\n",
" print(\"Credentials valid.\\nPossible values for `model_id` are:\\n\", models)\n",
"else:\n",
" print(\"Error when listing models. Are your credentials valid?\", resp.text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create the Gradient instance\n",
"You can specify different parameters such as the model name, max tokens generated, temperature, etc."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"llm = GradientLLM(\n",
" # `ID` listed in `$ gradient model list`\n",
" model_id=\"99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\",\n",
" # # optional: set new credentials, they default to environment variables\n",
" # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n",
" # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a Prompt Template\n",
"We will create a prompt template for Question and Answer."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initiate the LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run the LLMChain\n",
"Provide a question and run the LLMChain."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' The first team to win the Super Bowl was the New England Patriots. The Patriots won the'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"What NFL team won the Super Bowl in 1994?\"\n",
"\n",
"llm_chain.run(\n",
" question=question\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"vscode": {
"interpreter": {
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}