{ "cells": [ { "cell_type": "markdown", "id": "fdd7864c-93e6-4eb4-a923-b80d2ae4377d", "metadata": {}, "source": [ "# Structured Decoding with JSONFormer\n", "\n", "[JSONFormer](https://github.com/1rgs/jsonformer) is a library that wraps local HuggingFace pipeline models for structured decoding of a subset of the JSON Schema.\n", "\n", "It works by filling in the structure tokens and then sampling the content tokens from the model.\n", "\n", "**Warning - this module is still experimental**" ] }, { "cell_type": "code", "execution_count": 1, "id": "1617e327-d9a2-4ab6-aa9f-30a3167a3393", "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install --upgrade jsonformer > /dev/null" ] }, { "cell_type": "markdown", "id": "66bd89f1-8daa-433d-bb8f-5b0b3ae34b00", "metadata": {}, "source": [ "### HuggingFace Baseline\n", "\n", "First, let's establish a qualitative baseline by checking the output of the model without structured decoding." ] }, { "cell_type": "code", "execution_count": 2, "id": "d4d616ae-4d11-425f-b06c-c706d0386c68", "metadata": { "tags": [] }, "outputs": [], "source": [ "import logging\n", "logging.basicConfig(level=logging.ERROR)" ] }, { "cell_type": "code", "execution_count": 3, "id": "1bdc7b60-6ffb-4099-9fa6-13efdfc45b04", "metadata": { "tags": [] }, "outputs": [], "source": [ "from typing import Optional\n", "from langchain.tools import tool\n", "import os\n", "import json\n", "import requests\n", "\n", "HF_TOKEN = os.environ.get(\"HUGGINGFACE_API_KEY\")\n", "\n", "@tool\n", "def ask_star_coder(query: str, \n", " temperature: float = 1.0,\n", " max_new_tokens: float = 250):\n", " \"\"\"Query the BigCode StarCoder model about coding questions.\"\"\"\n", " url = \"https://api-inference.huggingface.co/models/bigcode/starcoder\"\n", " headers = {\n", " \"Authorization\": f\"Bearer {HF_TOKEN}\",\n", " \"content-type\": \"application/json\"\n", " }\n", " payload = {\n", " \"inputs\": f\"{query}\\n\\nAnswer:\",\n", " \"temperature\": temperature,\n", " 
\"max_new_tokens\": int(max_new_tokens),\n", " }\n", " response = requests.post(url, headers=headers, data=json.dumps(payload))\n", " response.raise_for_status()\n", " return json.loads(response.content.decode(\"utf-8\"))\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "d5522977-51e8-40eb-9403-8ab70b14908e", "metadata": { "tags": [] }, "outputs": [], "source": [ "prompt = \"\"\"You must respond using JSON format, with a single action and single action input.\n", "You may 'ask_star_coder' for help on coding problems.\n", "\n", "{arg_schema}\n", "\n", "EXAMPLES\n", "----\n", "Human: \"So what's all this about a GIL?\"\n", "AI Assistant:{{\n", " \"action\": \"ask_star_coder\",\n", " \"action_input\": {{\"query\": \"What is a GIL?\", \"temperature\": 0.0, \"max_new_tokens\": 100}}\n", "}}\n", "Observation: \"The GIL is Python's Global Interpreter Lock\"\n", "Human: \"Could you please write a calculator program in LISP?\"\n", "AI Assistant:{{\n", " \"action\": \"ask_star_coder\",\n", " \"action_input\": {{\"query\": \"Write a calculator program in LISP\", \"temperature\": 0.0, \"max_new_tokens\": 250}}\n", "}}\n", "Observation: \"(defun add (x y) (+ x y))\\n(defun sub (x y) (- x y ))\"\n", "Human: \"What's the difference between SGD and an SVM?\"\n", "AI Assistant:{{\n", " \"action\": \"ask_star_coder\",\n", " \"action_input\": {{\"query\": \"What's the difference between SGD and an SVM?\", \"temperature\": 1.0, \"max_new_tokens\": 250}}\n", "}}\n", "Observation: \"SGD stands for stochastic gradient descent, while an SVM is a Support Vector Machine.\"\n", "\n", "BEGIN! 
Answer the Human's question as best as you are able.\n", "------\n", "Human: 'What's the difference between an iterator and an iterable?'\n", "AI Assistant:\"\"\".format(arg_schema=ask_star_coder.args)" ] }, { "cell_type": "code", "execution_count": 5, "id": "9148e4b8-d370-4c05-a873-c121b65057b5", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 'What's the difference between an iterator and an iterable?'\n", "\n" ] } ], "source": [ "from transformers import pipeline\n", "from langchain.llms import HuggingFacePipeline\n", "\n", "hf_model = pipeline(\"text-generation\", model=\"cerebras/Cerebras-GPT-590M\", max_new_tokens=200)\n", "\n", "original_model = HuggingFacePipeline(pipeline=hf_model)\n", "\n", "generated = original_model.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n", "print(generated)" ] }, { "cell_type": "markdown", "id": "b6e7b9cf-8ce5-4f87-b4bf-100321ad2dd1", "metadata": {}, "source": [ "***That's not so impressive, is it? The model merely echoed the question and didn't follow the JSON format at all! Let's try with the structured decoder.***" ] }, { "cell_type": "markdown", "id": "96115154-a90a-46cb-9759-573860fc9b79", "metadata": {}, "source": [ "## JSONFormer LLM Wrapper\n", "\n", "Let's try that again, now providing the action input's JSON Schema to the model." 
] }, { "cell_type": "code", "execution_count": 6, "id": "30066ee7-9a92-4ae8-91bf-3262bf3c70c2", "metadata": { "tags": [] }, "outputs": [], "source": [ "decoder_schema = {\n", " \"title\": \"Decoding Schema\",\n", " \"type\": \"object\",\n", " \"properties\": {\n", " \"action\": {\"type\": \"string\", \"default\": ask_star_coder.name},\n", " \"action_input\": {\n", " \"type\": \"object\",\n", " \"properties\": ask_star_coder.args,\n", " }\n", " }\n", "} " ] }, { "cell_type": "code", "execution_count": 7, "id": "0f7447fe-22a9-47db-85b9-7adf0f19307d", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.experimental.llms import JsonFormer\n", "json_former = JsonFormer(json_schema=decoder_schema, pipeline=hf_model)" ] }, { "cell_type": "code", "execution_count": 10, "id": "d865e049-a5c3-4648-92db-8b912b7474ee", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\"action\": \"ask_star_coder\", \"action_input\": {\"query\": \"What's the difference between an iterator and an iter\", \"temperature\": 0.0, \"max_new_tokens\": 50.0}}\n" ] } ], "source": [ "results = json_former.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n", "print(results)" ] }, { "cell_type": "markdown", "id": "32077d74-0605-4138-9a10-0ce36637040d", "metadata": { "tags": [] }, "source": [ "**Voilà! The output is valid JSON that matches the schema, so it is free of parsing errors.**" ] }, { "cell_type": "code", "execution_count": null, "id": "da63ce31-de79-4462-a1a9-b726b698c5ba", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" } }, "nbformat": 4, "nbformat_minor": 5 }