Harrison/predibase (#8046)

Co-authored-by: Abhay Malik <32989166+Abhay-765@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-07-20 19:26:50 -07:00 committed by GitHub
parent 56c6ab1715
commit f99f497b2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 292 additions and 0 deletions

View File

@ -0,0 +1,24 @@
# Predibase
Learn how to use LangChain with models on Predibase.
## Setup
- Create a [Predibase](hhttps://predibase.com/) account and [API key](https://docs.predibase.com/sdk-guide/intro).
- Install the Predibase Python client with `pip install predibase`
- Use your API key to authenticate
### LLM
Predibase integrates with LangChain by implementing LLM module. You can see a short example below or a full notebook under LLM > Integrations > Predibase.
```python
import os
os.environ["PREDIBASE_API_TOKEN"] = "{PREDIBASE_API_TOKEN}"
from langchain.llms import Predibase
model = Predibase(model = 'vicuna-13b', predibase_api_key=os.environ.get('PREDIBASE_API_TOKEN'))
response = model("Can you recommend me a nice dry wine?")
print(response)
```

View File

@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predibase\n",
"\n",
"[Predibase](https://predibase.com/) allows you to train, finetune, and deploy any ML model—from linear regression to large language model. \n",
"\n",
"This example demonstrates using Langchain with models deployed on Predibase"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"\n",
"To run this notebook, you'll need a [Predibase account](https://predibase.com/free-trial/?utm_source=langchain) and an [API key](https://docs.predibase.com/sdk-guide/intro).\n",
"\n",
"You'll also need to install the Predibase Python package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install predibase\n",
"import os\n",
"\n",
"os.environ[\"PREDIBASE_API_TOKEN\"] = \"{PREDIBASE_API_TOKEN}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initial Call"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = model(\"Can you recommend me a nice dry wine?\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chain Call Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = Predibase(\n",
" model=\"vicuna-13b\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SequentialChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is an LLMChain to write a synopsis given a title of a play.\n",
"template = \"\"\"You are a playwright. Given the title of play, it is your job to write a synopsis for that title.\n",
"\n",
"Title: {title}\n",
"Playwright: This is a synopsis for the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
"synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is an LLMChain to write a review of a play given a synopsis.\n",
"template = \"\"\"You are a play critic from the New York Times. Given the synopsis of play, it is your job to write a review for that play.\n",
"\n",
"Play Synopsis:\n",
"{synopsis}\n",
"Review from a New York Times play critic of the above play:\"\"\"\n",
"prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
"review_chain = LLMChain(llm=llm, prompt=prompt_template)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# This is the overall chain where we run these two chains in sequence.\n",
"from langchain.chains import SimpleSequentialChain\n",
"\n",
"overall_chain = SimpleSequentialChain(\n",
" chains=[synopsis_chain, review_chain], verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"review = overall_chain.run(\"Tragedy at sunset on the beach\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fine-tuned LLM (Use your own fine-tuned LLM from Predibase)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Predibase\n",
"\n",
"model = Predibase(\n",
" model=\"my-finetuned-LLM\", predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\")\n",
")\n",
"# replace my-finetuned-LLM with the name of your model in Predibase"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# response = model(\"Can you help categorize the following emails into positive, negative, and neutral?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.9 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.9"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -43,6 +43,7 @@ from langchain.llms.openllm import OpenLLM
from langchain.llms.openlm import OpenLM
from langchain.llms.petals import Petals
from langchain.llms.pipelineai import PipelineAI
from langchain.llms.predibase import Predibase
from langchain.llms.predictionguard import PredictionGuard
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
from langchain.llms.replicate import Replicate
@ -100,6 +101,7 @@ __all__ = [
"OpenLM",
"Petals",
"PipelineAI",
"Predibase",
"PredictionGuard",
"PromptLayerOpenAI",
"PromptLayerOpenAIChat",
@ -156,6 +158,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"openlm": OpenLM,
"petals": Petals,
"pipelineai": PipelineAI,
"predibase": Predibase,
"replicate": Replicate,
"rwkv": RWKV,
"sagemaker_endpoint": SagemakerEndpoint,

View File

@ -0,0 +1,51 @@
from typing import Any, Dict, List, Mapping, Optional
from pydantic import Field
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
class Predibase(LLM):
"""Use your Predibase models with Langchain.
To use, you should have the ``predibase`` python package installed,
and have your Predibase API key.
"""
model: str
predibase_api_key: str
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@property
def _llm_type(self) -> str:
return "predibase"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> str:
try:
from predibase import PredibaseClient
pc = PredibaseClient(token=self.predibase_api_key)
except ImportError as e:
raise ImportError(
"Could not import Predibase Python package. "
"Please install it with `pip install predibase`."
) from e
except ValueError as e:
raise ValueError("Your API key is not correct. Please try again") from e
# load model and version
results = pc.prompt(prompt, model_name=self.model)
return results[0].response
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
**{"model_kwargs": self.model_kwargs},
}