From d85b04be7f49daa59156d6897f542cf25c3d76fb Mon Sep 17 00:00:00 2001
From: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Date: Sun, 14 May 2023 22:40:03 +0000
Subject: [PATCH] Add RELLM and JSONFormer experimental LLM decoding (#4185)

[RELLM](https://github.com/r2d4/rellm) is a library that wraps local
HuggingFace pipeline models for structured decoding. RELLM works by
generating tokens one at a time. At each step, it masks tokens that don't
conform to the provided partial regular expression.

[JSONFormer](https://github.com/1rgs/jsonformer) is a bit different: it
sequentially writes out the JSON keys itself and then decodes each value
directly with the model.
---
 .../jsonformer_experimental.ipynb             | 280 ++++++++++++++++++
 .../integrations/rellm_experimental.ipynb     | 208 +++++++++++++
 langchain/experimental/llms/__init__.py       |   6 +
 .../experimental/llms/jsonformer_decoder.py   |  60 ++++
 langchain/experimental/llms/rellm_decoder.py  |  67 +++++
 5 files changed, 621 insertions(+)
 create mode 100644 docs/modules/models/llms/integrations/jsonformer_experimental.ipynb
 create mode 100644 docs/modules/models/llms/integrations/rellm_experimental.ipynb
 create mode 100644 langchain/experimental/llms/__init__.py
 create mode 100644 langchain/experimental/llms/jsonformer_decoder.py
 create mode 100644 langchain/experimental/llms/rellm_decoder.py

diff --git a/docs/modules/models/llms/integrations/jsonformer_experimental.ipynb b/docs/modules/models/llms/integrations/jsonformer_experimental.ipynb
new file mode 100644
index 00000000..8cff4ba5
--- /dev/null
+++ b/docs/modules/models/llms/integrations/jsonformer_experimental.ipynb
@@ -0,0 +1,280 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fdd7864c-93e6-4eb4-a923-b80d2ae4377d",
+   "metadata": {},
+   "source": [
+    "# Structured Decoding with JSONFormer\n",
+    "\n",
+    "[JSONFormer](https://github.com/1rgs/jsonformer) is a library that wraps local HuggingFace pipeline models for structured decoding of a subset of JSON Schema.\n",
+    "\n",
+    "It works by filling in the structure tokens itself and then sampling only the content tokens from the model.\n",
+    "\n",
+    "**Warning - this module is still experimental**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "1617e327-d9a2-4ab6-aa9f-30a3167a3393",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "!pip install --upgrade jsonformer > /dev/null"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "66bd89f1-8daa-433d-bb8f-5b0b3ae34b00",
+   "metadata": {},
+   "source": [
+    "### HuggingFace Baseline\n",
+    "\n",
+    "First, let's establish a qualitative baseline by checking the output of the model without structured decoding."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "d4d616ae-4d11-425f-b06c-c706d0386c68",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "logging.basicConfig(level=logging.ERROR)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "1bdc7b60-6ffb-4099-9fa6-13efdfc45b04",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from typing import Optional\n",
+    "from langchain.tools import tool\n",
+    "import os\n",
+    "import json\n",
+    "import requests\n",
+    "\n",
+    "HF_TOKEN = os.environ.get(\"HUGGINGFACE_API_KEY\")\n",
+    "\n",
+    "@tool\n",
+    "def ask_star_coder(query: str,\n",
+    "                   temperature: float = 1.0,\n",
+    "                   max_new_tokens: float = 250):\n",
+    "    \"\"\"Query the BigCode StarCoder model about coding questions.\"\"\"\n",
+    "    url = \"https://api-inference.huggingface.co/models/bigcode/starcoder\"\n",
+    "    headers = {\n",
+    "        \"Authorization\": f\"Bearer {HF_TOKEN}\",\n",
+    "        \"content-type\": \"application/json\"\n",
+    "    }\n",
+    "    payload = {\n",
+    "        \"inputs\": f\"{query}\\n\\nAnswer:\",\n",
+    "        \"temperature\": temperature,\n",
+    "        \"max_new_tokens\": int(max_new_tokens),\n",
+    "    }\n",
+    "    response = requests.post(url, headers=headers, data=json.dumps(payload))\n",
+    "    response.raise_for_status()\n",
+    "    return json.loads(response.content.decode(\"utf-8\"))\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "d5522977-51e8-40eb-9403-8ab70b14908e",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "prompt = \"\"\"You must respond using JSON format, with a single action and single action input.\n",
+    "You may 'ask_star_coder' for help on coding problems.\n",
+    "\n",
+    "{arg_schema}\n",
+    "\n",
+    "EXAMPLES\n",
+    "----\n",
+    "Human: \"So what's all this about a GIL?\"\n",
+    "AI Assistant:{{\n",
+    " \"action\": \"ask_star_coder\",\n",
+    " \"action_input\": {{\"query\": \"What is a GIL?\", \"temperature\": 0.0, \"max_new_tokens\": 100}}\n",
+    "}}\n",
+    "Observation: \"The GIL is Python's Global Interpreter Lock\"\n",
+    "Human: \"Could you please write a calculator program in LISP?\"\n",
+    "AI Assistant:{{\n",
+    " \"action\": \"ask_star_coder\",\n",
+    " \"action_input\": {{\"query\": \"Write a calculator program in LISP\", \"temperature\": 0.0, \"max_new_tokens\": 250}}\n",
+    "}}\n",
+    "Observation: \"(defun add (x y) (+ x y))\\n(defun sub (x y) (- x y ))\"\n",
+    "Human: \"What's the difference between SGD and an SVM?\"\n",
+    "AI Assistant:{{\n",
+    " \"action\": \"ask_star_coder\",\n",
+    " \"action_input\": {{\"query\": \"What's the difference between SGD and an SVM?\", \"temperature\": 1.0, \"max_new_tokens\": 250}}\n",
+    "}}\n",
+    "Observation: \"SGD stands for stochastic gradient descent, while an SVM is a Support Vector Machine.\"\n",
+    "\n",
+    "BEGIN! Answer the Human's question as best as you are able.\n",
+    "------\n",
+    "Human: 'What's the difference between an iterator and an iterable?'\n",
+    "AI Assistant:\"\"\".format(arg_schema=ask_star_coder.args)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "9148e4b8-d370-4c05-a873-c121b65057b5",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " 'What's the difference between an iterator and an iterable?'\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import pipeline\n",
+    "from langchain.llms import HuggingFacePipeline\n",
+    "\n",
+    "hf_model = pipeline(\"text-generation\", model=\"cerebras/Cerebras-GPT-590M\", max_new_tokens=200)\n",
+    "\n",
+    "original_model = HuggingFacePipeline(pipeline=hf_model)\n",
+    "\n",
+    "generated = original_model.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n",
+    "print(generated)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b6e7b9cf-8ce5-4f87-b4bf-100321ad2dd1",
+   "metadata": {},
+   "source": [
+    "***That's not so impressive, is it? It didn't follow the JSON format at all! Let's try with the structured decoder.***"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "96115154-a90a-46cb-9759-573860fc9b79",
+   "metadata": {},
+   "source": [
+    "## JSONFormer LLM Wrapper\n",
+    "\n",
+    "Let's try that again, now providing the action input's JSON Schema to the model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "30066ee7-9a92-4ae8-91bf-3262bf3c70c2",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "decoder_schema = {\n",
+    "    \"title\": \"Decoding Schema\",\n",
+    "    \"type\": \"object\",\n",
+    "    \"properties\": {\n",
+    "        \"action\": {\"type\": \"string\", \"default\": ask_star_coder.name},\n",
+    "        \"action_input\": {\n",
+    "            \"type\": \"object\",\n",
+    "            \"properties\": ask_star_coder.args,\n",
+    "        }\n",
+    "    }\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "0f7447fe-22a9-47db-85b9-7adf0f19307d",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from langchain.experimental.llms import JsonFormer\n",
+    "json_former = JsonFormer(json_schema=decoder_schema, pipeline=hf_model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "d865e049-a5c3-4648-92db-8b912b7474ee",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{\"action\": \"ask_star_coder\", \"action_input\": {\"query\": \"What's the difference between an iterator and an iter\", \"temperature\": 0.0, \"max_new_tokens\": 50.0}}\n"
+     ]
+    }
+   ],
+   "source": [
+    "results = json_former.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n",
+    "print(results)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "32077d74-0605-4138-9a10-0ce36637040d",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "**Voila! Free of parsing errors.**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "da63ce31-de79-4462-a1a9-b726b698c5ba",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/modules/models/llms/integrations/rellm_experimental.ipynb b/docs/modules/models/llms/integrations/rellm_experimental.ipynb
new file mode 100644
index 00000000..395645b5
--- /dev/null
+++ b/docs/modules/models/llms/integrations/rellm_experimental.ipynb
@@ -0,0 +1,208 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fdd7864c-93e6-4eb4-a923-b80d2ae4377d",
+   "metadata": {},
+   "source": [
+    "# Structured Decoding with RELLM\n",
+    "\n",
+    "[RELLM](https://github.com/r2d4/rellm) is a library that wraps local HuggingFace pipeline models for structured decoding.\n",
+    "\n",
+    "It works by generating tokens one at a time. At each step, it masks tokens that don't conform to the provided partial regular expression.\n",
+    "\n",
+    "\n",
+    "**Warning - this module is still experimental**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "1617e327-d9a2-4ab6-aa9f-30a3167a3393",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "!pip install rellm > /dev/null"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "66bd89f1-8daa-433d-bb8f-5b0b3ae34b00",
+   "metadata": {},
+   "source": [
+    "### HuggingFace Baseline\n",
+    "\n",
+    "First, let's establish a qualitative baseline by checking the output of the model without structured decoding."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "d4d616ae-4d11-425f-b06c-c706d0386c68",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "logging.basicConfig(level=logging.ERROR)\n",
+    "prompt = \"\"\"Human: \"What's the capital of the United States?\"\n",
+    "AI Assistant:{\n",
+    " \"action\": \"Final Answer\",\n",
+    " \"action_input\": \"The capital of the United States is Washington D.C.\"\n",
+    "}\n",
+    "Human: \"What's the capital of Pennsylvania?\"\n",
+    "AI Assistant:{\n",
+    " \"action\": \"Final Answer\",\n",
+    " \"action_input\": \"The capital of Pennsylvania is Harrisburg.\"\n",
+    "}\n",
+    "Human: \"What's 2 + 5?\"\n",
+    "AI Assistant:{\n",
+    " \"action\": \"Final Answer\",\n",
+    " \"action_input\": \"2 + 5 = 7.\"\n",
+    "}\n",
+    "Human: 'What's the capital of Maryland?'\n",
+    "AI Assistant:\"\"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "9148e4b8-d370-4c05-a873-c121b65057b5",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "generations=[[Generation(text=' \"What\\'s the capital of Maryland?\"\\n', generation_info=None)]] llm_output=None\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import pipeline\n",
+    "from langchain.llms import HuggingFacePipeline\n",
+    "\n",
+    "hf_model = pipeline(\"text-generation\", model=\"cerebras/Cerebras-GPT-590M\", max_new_tokens=200)\n",
+    "\n",
+    "original_model = HuggingFacePipeline(pipeline=hf_model)\n",
+    "\n",
+    "generated = original_model.generate([prompt], stop=[\"Human:\"])\n",
+    "print(generated)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b6e7b9cf-8ce5-4f87-b4bf-100321ad2dd1",
+   "metadata": {},
+   "source": [
+    "***That's not so impressive, is it? It didn't answer the question, and it didn't follow the JSON format at all! Let's try with the structured decoder.***"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "96115154-a90a-46cb-9759-573860fc9b79",
+   "metadata": {},
+   "source": [
+    "## RELLM LLM Wrapper\n",
+    "\n",
+    "Let's try that again, now providing a regex that matches the structured JSON format."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "65c12e2a-bd7f-4cf0-8ef8-92cfa31c92ef",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import regex  # Note: this is the `regex` library, NOT Python's re stdlib module\n",
+    "\n",
+    "# We'll choose a regex that matches a structured JSON string that looks like:\n",
+    "# {\n",
+    "#  \"action\": \"Final Answer\",\n",
+    "#  \"action_input\": string or dict\n",
+    "# }\n",
+    "pattern = regex.compile(r'\\{\\s*\"action\":\\s*\"Final Answer\",\\s*\"action_input\":\\s*(\\{.*\\}|\"[^\"]*\")\\s*\\}\\nHuman:')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "de85b1f8-b405-4291-b6d0-4b2c56e77ad6",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{\"action\": \"Final Answer\",\n",
+      " \"action_input\": \"The capital of Maryland is Baltimore.\"\n",
+      "}\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain.experimental.llms import RELLM\n",
+    "\n",
+    "model = RELLM(pipeline=hf_model, regex=pattern, max_new_tokens=200)\n",
+    "\n",
+    "generated = model.predict(prompt, stop=[\"Human:\"])\n",
+    "print(generated)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "32077d74-0605-4138-9a10-0ce36637040d",
+   "metadata": {
+    "tags": []
+   },
+   "source": [
+    "**Voila! Free of parsing errors.**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4bd208a1-779c-4c47-97d9-9115d15d441f",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/experimental/llms/__init__.py b/langchain/experimental/llms/__init__.py
new file mode 100644
index 00000000..bac4ff70
--- /dev/null
+++ b/langchain/experimental/llms/__init__.py
@@ -0,0 +1,6 @@
+"""Experimental LLM wrappers."""
+
+from langchain.experimental.llms.jsonformer_decoder import JsonFormer
+from langchain.experimental.llms.rellm_decoder import RELLM
+
+__all__ = ["RELLM", "JsonFormer"]
diff --git a/langchain/experimental/llms/jsonformer_decoder.py b/langchain/experimental/llms/jsonformer_decoder.py
new file mode 100644
index 00000000..f0305f3f
--- /dev/null
+++ b/langchain/experimental/llms/jsonformer_decoder.py
@@ -0,0 +1,60 @@
+"""Experimental implementation of jsonformer wrapped LLM."""
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, List, Optional, cast
+
+from pydantic import Field, root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+
+if TYPE_CHECKING:
+    import jsonformer
+
+
+def import_jsonformer() -> jsonformer:
+    """Lazily import jsonformer."""
+    try:
+        import jsonformer
+    except ImportError:
+        raise ValueError(
+            "Could not import jsonformer python package. "
+            "Please install it with `pip install jsonformer`."
+        )
+    return jsonformer
+
+
+class JsonFormer(HuggingFacePipeline):
+    json_schema: dict = Field(..., description="The JSON Schema to complete.")
+    max_new_tokens: int = Field(
+        default=200, description="Maximum number of new tokens to generate."
+ ) + debug: bool = Field(default=False, description="Debug mode.") + + @root_validator + def check_jsonformer_installation(cls, values: dict) -> dict: + import_jsonformer() + return values + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: + jsonformer = import_jsonformer() + from transformers import Text2TextGenerationPipeline + + pipeline = cast(Text2TextGenerationPipeline, self.pipeline) + + model = jsonformer.Jsonformer( + model=pipeline.model, + tokenizer=pipeline.tokenizer, + json_schema=self.json_schema, + prompt=prompt, + max_number_tokens=self.max_new_tokens, + debug=self.debug, + ) + text = model() + return json.dumps(text) diff --git a/langchain/experimental/llms/rellm_decoder.py b/langchain/experimental/llms/rellm_decoder.py new file mode 100644 index 00000000..8449b775 --- /dev/null +++ b/langchain/experimental/llms/rellm_decoder.py @@ -0,0 +1,67 @@ +"""Experimental implementation of RELLM wrapped LLM.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional, cast + +from pydantic import Field, root_validator + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.huggingface_pipeline import HuggingFacePipeline +from langchain.llms.utils import enforce_stop_tokens + +if TYPE_CHECKING: + import rellm + from regex import Pattern as RegexPattern +else: + try: + from regex import Pattern as RegexPattern + except ImportError: + pass + + +def import_rellm() -> rellm: + """Lazily import rellm.""" + try: + import rellm + except ImportError: + raise ValueError( + "Could not import rellm python package. " + "Please install it with `pip install rellm`." + ) + return rellm + + +class RELLM(HuggingFacePipeline): + regex: RegexPattern = Field(..., description="The structured format to complete.") + max_new_tokens: int = Field( + default=200, description="Maximum number of new tokens to generate." + ) + + @root_validator + def check_rellm_installation(cls, values: dict) -> dict: + import_rellm() + return values + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + ) -> str: + rellm = import_rellm() + from transformers import Text2TextGenerationPipeline + + pipeline = cast(Text2TextGenerationPipeline, self.pipeline) + + text = rellm.complete_re( + prompt, + self.regex, + tokenizer=pipeline.tokenizer, + model=pipeline.model, + max_new_tokens=self.max_new_tokens, + ) + if stop is not None: + # This is a bit hacky, but I can't figure out a better way to enforce + # stop tokens when making calls to huggingface_hub. + text = enforce_stop_tokens(text, stop) + return text