From 707741de58c11afe92ccd4e6542b22d0070255b3 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Mon, 24 Apr 2023 22:27:22 -0700
Subject: [PATCH] Harrison/prediction guard (#3490)

Co-authored-by: Daniel Whitenack
---
 docs/ecosystem/predictionguard.md             |  56 +++++++
 .../llms/integrations/predictionguard.ipynb   | 155 ++++++++++++++++++
 langchain/llms/__init__.py                    |   2 +
 langchain/llms/predictionguard.py             | 109 ++++++++++++
 .../llms/test_predictionguard.py              |  10 ++
 5 files changed, 332 insertions(+)
 create mode 100644 docs/ecosystem/predictionguard.md
 create mode 100644 docs/modules/models/llms/integrations/predictionguard.ipynb
 create mode 100644 langchain/llms/predictionguard.py
 create mode 100644 tests/integration_tests/llms/test_predictionguard.py

diff --git a/docs/ecosystem/predictionguard.md b/docs/ecosystem/predictionguard.md
new file mode 100644
index 00000000..1fffb550
--- /dev/null
+++ b/docs/ecosystem/predictionguard.md
@@ -0,0 +1,56 @@
+# Prediction Guard
+
+This page covers how to use the Prediction Guard ecosystem within LangChain.
+It is broken into two parts: installation and setup, and then references to specific Prediction Guard wrappers.
+
+## Installation and Setup
+- Install the Python SDK with `pip install predictionguard`
+- Get a Prediction Guard access token (as described [here](https://docs.predictionguard.com/)) and set it as an environment variable (`PREDICTIONGUARD_TOKEN`)
+
+## LLM Wrapper
+
+There exists a Prediction Guard LLM wrapper, which you can access with
+```python
+from langchain.llms import PredictionGuard
+```
+
+You can provide the name of your Prediction Guard "proxy" as an argument when initializing the LLM:
+```python
+pgllm = PredictionGuard(name="your-text-gen-proxy")
+```
+
+Alternatively, you can use Prediction Guard's default proxy for SOTA LLMs:
+```python
+pgllm = PredictionGuard(name="default-text-gen")
+```
+
+You can also provide your access token directly as an argument:
+```python
+pgllm = PredictionGuard(name="default-text-gen", token="<your access token>")
+```
+
+## Example usage
+
+Basic usage of the LLM wrapper:
+```python
+from langchain.llms import PredictionGuard
+
+pgllm = PredictionGuard(name="default-text-gen")
+pgllm("Tell me a joke")
+```
+
+Basic LLM chaining with the Prediction Guard wrapper:
+```python
+from langchain import PromptTemplate, LLMChain
+from langchain.llms import PredictionGuard
+
+template = """Question: {question}
+
+Answer: Let's think step by step."""
+prompt = PromptTemplate(template=template, input_variables=["question"])
+llm_chain = LLMChain(prompt=prompt, llm=PredictionGuard(name="default-text-gen"), verbose=True)
+
+question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"
+
+llm_chain.predict(question=question)
+```
\ No newline at end of file
diff --git a/docs/modules/models/llms/integrations/predictionguard.ipynb b/docs/modules/models/llms/integrations/predictionguard.ipynb
new file mode 100644
index 00000000..78fd8390
--- /dev/null
+++ b/docs/modules/models/llms/integrations/predictionguard.ipynb
@@ -0,0 +1,155 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# PredictionGuard\n",
+    "\n",
+    "How to use the PredictionGuard wrapper"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "3RqWPav7AtKL"
+   },
+   "outputs": [],
+   "source": [
+    "! pip install predictionguard langchain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "id": "2xe8JEUwA7_y"
+   },
+   "outputs": [],
+   "source": [
+    "import predictionguard as pg\n",
+    "from langchain.llms import PredictionGuard"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "mesCTyhnJkNS"
+   },
+   "source": [
+    "## Basic LLM usage\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "Ua7Mw1N4HcER"
+   },
+   "outputs": [],
+   "source": [
+    "pgllm = PredictionGuard(name=\"default-text-gen\", token=\"<your access token>\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "Qo2p5flLHxrB"
+   },
+   "outputs": [],
+   "source": [
+    "pgllm(\"Tell me a joke\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "v3MzIUItJ8kV"
+   },
+   "source": [
+    "## Chaining"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "pPegEZExILrT"
+   },
+   "outputs": [],
+   "source": [
+    "from langchain import PromptTemplate, LLMChain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "suxw62y-J-bg"
+   },
+   "outputs": [],
+   "source": [
+    "template = \"\"\"Question: {question}\n",
+    "\n",
+    "Answer: Let's think step by step.\"\"\"\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
+    "llm_chain = LLMChain(prompt=prompt, llm=pgllm, verbose=True)\n",
+    "\n",
+    "question = \"What NFL team won the Super Bowl in the year Justin Bieber was born?\"\n",
+    "\n",
+    "llm_chain.predict(question=question)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "l2bc26KHKr7n"
+   },
+   "outputs": [],
+   "source": [
+    "template = \"\"\"Write a {adjective} poem about {subject}.\"\"\"\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"adjective\", \"subject\"])\n",
+    "llm_chain = LLMChain(prompt=prompt, llm=pgllm, verbose=True)\n",
+    "\n",
+    "llm_chain.predict(adjective=\"sad\", subject=\"ducks\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "I--eSa2PLGqq"
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index fcb98f4f..439d934e 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -20,6 +20,7 @@ from langchain.llms.modal import Modal
 from langchain.llms.nlpcloud import NLPCloud
 from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
 from langchain.llms.petals import Petals
+from langchain.llms.predictionguard import PredictionGuard
 from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
 from langchain.llms.replicate import Replicate
 from langchain.llms.rwkv import RWKV
@@ -59,6 +60,7 @@ __all__ = [
     "StochasticAI",
     "Writer",
     "RWKV",
+    "PredictionGuard",
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
diff --git a/langchain/llms/predictionguard.py b/langchain/llms/predictionguard.py
new file mode 100644
index 00000000..c5ba6165
--- /dev/null
+++ b/langchain/llms/predictionguard.py
@@ -0,0 +1,109 @@
+"""Wrapper around Prediction Guard APIs."""
+import logging
+from typing import Any, Dict, List, Optional
+
+from pydantic import Extra, root_validator
+
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+from langchain.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class PredictionGuard(LLM):
+    """Wrapper around Prediction Guard large language models.
+    To use, you should have the ``predictionguard`` python package installed, and the
+    environment variable ``PREDICTIONGUARD_TOKEN`` set with your access token, or pass
+    it as a named parameter to the constructor.
+    Example:
+        .. code-block:: python
+            pgllm = PredictionGuard(name="text-gen-proxy-name", token="my-access-token")
+    """
+
+    client: Any  #: :meta private:
+    name: Optional[str] = "default-text-gen"
+    """Proxy name to use."""
+
+    max_tokens: int = 256
+    """Denotes the number of tokens to predict per generation."""
+
+    temperature: float = 0.75
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    token: Optional[str] = None
+
+    stop: Optional[List[str]] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the access token and python package exist in the environment."""
+        token = get_from_dict_or_env(values, "token", "PREDICTIONGUARD_TOKEN")
+        try:
+            import predictionguard as pg
+
+            values["client"] = pg.Client(token=token)
+        except ImportError:
+            raise ValueError(
+                "Could not import predictionguard python package. "
+                "Please install it with `pip install predictionguard`."
+            )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling the Prediction Guard API."""
+        return {
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"name": self.name}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "predictionguard"
+
+    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Call out to Prediction Guard's model proxy.
+        Args:
+            prompt: The prompt to pass into the model.
+        Returns:
+            The string generated by the model.
+        Example:
+            .. code-block:: python
+                response = pgllm("Tell me a joke.")
+        """
+        params = self._default_params
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            params["stop_sequences"] = self.stop
+        else:
+            params["stop_sequences"] = stop
+
+        response = self.client.predict(
+            name=self.name,
+            data={
+                "prompt": prompt,
+                "max_tokens": params["max_tokens"],
+                "temperature": params["temperature"],
+            },
+        )
+        text = response["text"]
+
+        # If stop tokens are provided, Prediction Guard's endpoint returns them.
+        # In order to make this consistent with other endpoints, we strip them.
+        if stop is not None or self.stop is not None:
+            text = enforce_stop_tokens(text, params["stop_sequences"])
+
+        return text
diff --git a/tests/integration_tests/llms/test_predictionguard.py b/tests/integration_tests/llms/test_predictionguard.py
new file mode 100644
index 00000000..0100fba9
--- /dev/null
+++ b/tests/integration_tests/llms/test_predictionguard.py
@@ -0,0 +1,10 @@
+"""Test Prediction Guard API wrapper."""
+
+from langchain.llms.predictionguard import PredictionGuard
+
+
+def test_predictionguard_call() -> None:
+    """Test valid call to prediction guard."""
+    llm = PredictionGuard(name="default-text-gen")
+    output = llm("Say foo:")
+    assert isinstance(output, str)
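
Beyond the basic calls shown in the docs above, the wrapper added in `langchain/llms/predictionguard.py` also exposes `max_tokens`, `temperature`, and a `stop` list that is stripped from the returned text via `enforce_stop_tokens`. Below is a minimal usage sketch of passing those parameters; it is illustrative and not part of the patch, and assumes `PREDICTIONGUARD_TOKEN` is set in the environment and that a `default-text-gen` proxy is available in your Prediction Guard account. The parameter values shown are arbitrary.

```python
# Minimal usage sketch (illustrative, not part of this patch).
# Assumes PREDICTIONGUARD_TOKEN is exported and the "default-text-gen"
# proxy exists in your Prediction Guard account.
from langchain.llms import PredictionGuard

pgllm = PredictionGuard(
    name="default-text-gen",  # Prediction Guard proxy to route the call through
    max_tokens=128,           # tokens to generate per call (wrapper default: 256)
    temperature=0.5,          # sampling temperature (wrapper default: 0.75)
    stop=["Question:"],       # stop sequences, stripped from the returned text
)

# Note: supplying `stop` both here at call time and in the constructor above
# raises a ValueError, since the wrapper rejects stop sequences given twice.
print(pgllm("Answer in one sentence: what is LangChain?"))
```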