Harrison/prediction guard (#3490)

Co-authored-by: Daniel Whitenack <whitenack.daniel@gmail.com>
This commit is contained in:
Harrison Chase 2023-04-24 22:27:22 -07:00 committed by GitHub
parent 7257f9e015
commit 707741de58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 332 additions and 0 deletions

View File

@ -0,0 +1,56 @@
# Prediction Guard
This page covers how to use the Prediction Guard ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Prediction Guard wrappers.
## Installation and Setup
- Install the Python SDK with `pip install predictionguard`
- Get an Prediction Guard access token (as described [here](https://docs.predictionguard.com/)) and set it as an environment variable (`PREDICTIONGUARD_TOKEN`)
## LLM Wrapper
There exists a Prediction Guard LLM wrapper, which you can access with
```python
from langchain.llms import PredictionGuard
```
You can provide the name of your Prediction Guard "proxy" as an argument when initializing the LLM:
```python
pgllm = PredictionGuard(name="your-text-gen-proxy")
```
Alternatively, you can use Prediction Guard's default proxy for SOTA LLMs:
```python
pgllm = PredictionGuard(name="default-text-gen")
```
You can also provide your access token directly as an argument:
```python
pgllm = PredictionGuard(name="default-text-gen", token="<your access token>")
```
## Example usage
Basic usage of the LLM wrapper:
```python
from langchain.llms import PredictionGuard
pgllm = PredictionGuard(name="default-text-gen")
pgllm("Tell me a joke")
```
Basic LLM Chaining with the Prediction Guard wrapper:
```python
from langchain import PromptTemplate, LLMChain
from langchain.llms import PredictionGuard
template = """Question: {question}
Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=PredictionGuard(name="default-text-gen"), verbose=True)
question = "What NFL team won the Super Bowl in the year Justin Beiber was born?"
llm_chain.predict(question=question)
```

View File

@ -0,0 +1,155 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PredictionGuard\n",
"\n",
"How to use PredictionGuard wrapper"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3RqWPav7AtKL"
},
"outputs": [],
"source": [
"! pip install predictionguard langchain"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "2xe8JEUwA7_y"
},
"outputs": [],
"source": [
"import predictionguard as pg\n",
"from langchain.llms import PredictionGuard"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mesCTyhnJkNS"
},
"source": [
"## Basic LLM usage\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ua7Mw1N4HcER"
},
"outputs": [],
"source": [
"pgllm = PredictionGuard(name=\"default-text-gen\", token=\"<your access token>\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qo2p5flLHxrB"
},
"outputs": [],
"source": [
"pgllm(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v3MzIUItJ8kV"
},
"source": [
"## Chaining"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pPegEZExILrT"
},
"outputs": [],
"source": [
"from langchain import PromptTemplate, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "suxw62y-J-bg"
},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"llm_chain = LLMChain(prompt=prompt, llm=pgllm, verbose=True)\n",
"\n",
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.predict(question=question)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l2bc26KHKr7n"
},
"outputs": [],
"source": [
"template = \"\"\"Write a {adjective} poem about {subject}.\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"adjective\", \"subject\"])\n",
"llm_chain = LLMChain(prompt=prompt, llm=pgllm, verbose=True)\n",
"\n",
"llm_chain.predict(adjective=\"sad\", subject=\"ducks\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I--eSa2PLGqq"
},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 1
}

View File

@ -20,6 +20,7 @@ from langchain.llms.modal import Modal
from langchain.llms.nlpcloud import NLPCloud from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
from langchain.llms.petals import Petals from langchain.llms.petals import Petals
from langchain.llms.predictionguard import PredictionGuard
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
from langchain.llms.replicate import Replicate from langchain.llms.replicate import Replicate
from langchain.llms.rwkv import RWKV from langchain.llms.rwkv import RWKV
@ -59,6 +60,7 @@ __all__ = [
"StochasticAI", "StochasticAI",
"Writer", "Writer",
"RWKV", "RWKV",
"PredictionGuard",
] ]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = { type_to_cls_dict: Dict[str, Type[BaseLLM]] = {

View File

@ -0,0 +1,109 @@
"""Wrapper around Prediction Guard APIs."""
import logging
from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PredictionGuard(LLM):
"""Wrapper around Prediction Guard large language models.
To use, you should have the ``predictionguard`` python package installed, and the
environment variable ``PREDICTIONGUARD_TOKEN`` set with your access token, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
pgllm = PredictionGuard(name="text-gen-proxy-name", token="my-access-token")
"""
client: Any #: :meta private:
name: Optional[str] = "default-text-gen"
"""Proxy name to use."""
max_tokens: int = 256
"""Denotes the number of tokens to predict per generation."""
temperature: float = 0.75
"""A non-negative float that tunes the degree of randomness in generation."""
token: Optional[str] = None
stop: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the access token and python package exists in environment."""
token = get_from_dict_or_env(values, "token", "PREDICTIONGUARD_TOKEN")
try:
import predictionguard as pg
values["client"] = pg.Client(token=token)
except ImportError:
raise ValueError(
"Could not import predictionguard python package. "
"Please install it with `pip install predictionguard`."
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"max_tokens": self.max_tokens,
"temperature": self.temperature,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"name": self.name}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "predictionguard"
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call out to Prediction Guard's model proxy.
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = pgllm("Tell me a joke.")
"""
params = self._default_params
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
params["stop_sequences"] = self.stop
else:
params["stop_sequences"] = stop
response = self.client.predict(
name=self.name,
data={
"prompt": prompt,
"max_tokens": params["max_tokens"],
"temperature": params["temperature"],
},
)
text = response["text"]
# If stop tokens are provided, Prediction Guard's endpoint returns them.
# In order to make this consistent with other endpoints, we strip them.
if stop is not None or self.stop is not None:
text = enforce_stop_tokens(text, params["stop_sequences"])
return text

View File

@ -0,0 +1,10 @@
"""Test Prediction Guard API wrapper."""
from langchain.llms.predictionguard import PredictionGuard
def test_predictionguard_call() -> None:
"""Test valid call to prediction guard."""
llm = PredictionGuard(name="default-text-gen")
output = llm("Say foo:")
assert isinstance(output, str)