{ "cells": [ { "cell_type": "markdown", "id": "fdd7864c-93e6-4eb4-a923-b80d2ae4377d", "metadata": {}, "source": [ "# Structured Decoding with RELLM\n", "\n", "[RELLM](https://github.com/r2d4/rellm) is a library that wraps local Hugging Face pipeline models for structured decoding.\n", "\n", "It works by generating tokens one at a time. At each step, it masks out every token that could not extend a partial match of the provided regular expression.\n", "\n", "**Warning: this module is still experimental.**" ] }, { "cell_type": "code", "execution_count": 1, "id": "1617e327-d9a2-4ab6-aa9f-30a3167a3393", "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install rellm > /dev/null" ] }, { "cell_type": "markdown", "id": "66bd89f1-8daa-433d-bb8f-5b0b3ae34b00", "metadata": {}, "source": [ "## Hugging Face Baseline\n", "\n", "First, let's establish a qualitative baseline by checking the output of the model without structured decoding." ] }, { "cell_type": "code", "execution_count": 2, "id": "d4d616ae-4d11-425f-b06c-c706d0386c68", "metadata": { "tags": [] }, "outputs": [], "source": [ "import logging\n", "logging.basicConfig(level=logging.ERROR)\n", "prompt = \"\"\"Human: \"What's the capital of the United States?\"\n", "AI Assistant:{\n", " \"action\": \"Final Answer\",\n", " \"action_input\": \"The capital of the United States is Washington D.C.\"\n", "}\n", "Human: \"What's the capital of Pennsylvania?\"\n", "AI Assistant:{\n", " \"action\": \"Final Answer\",\n", " \"action_input\": \"The capital of Pennsylvania is Harrisburg.\"\n", "}\n", "Human: \"What 2 + 5?\"\n", "AI Assistant:{\n", " \"action\": \"Final Answer\",\n", " \"action_input\": \"2 + 5 = 7.\"\n", "}\n", "Human: 'What's the capital of Maryland?'\n", "AI Assistant:\"\"\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "9148e4b8-d370-4c05-a873-c121b65057b5", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for 
open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "generations=[[Generation(text=' \"What\\'s the capital of Maryland?\"\\n', generation_info=None)]] llm_output=None\n" ] } ], "source": [ "from transformers import pipeline\n", "from langchain.llms import HuggingFacePipeline\n", "\n", "hf_model = pipeline(\"text-generation\", model=\"cerebras/Cerebras-GPT-590M\", max_new_tokens=200)\n", "\n", "original_model = HuggingFacePipeline(pipeline=hf_model)\n", "\n", "generated = original_model.generate([prompt], stop=[\"Human:\"])\n", "print(generated)" ] }, { "cell_type": "markdown", "id": "b6e7b9cf-8ce5-4f87-b4bf-100321ad2dd1", "metadata": {}, "source": [ "***That's not so impressive, is it? It didn't answer the question and it didn't follow the JSON format at all! Let's try with the structured decoder.***" ] }, { "cell_type": "markdown", "id": "96115154-a90a-46cb-9759-573860fc9b79", "metadata": {}, "source": [ "## RELLM LLM Wrapper\n", "\n", "Let's try that again, now providing a regex to match the JSON structured format." 
] }, { "cell_type": "code", "execution_count": 4, "id": "65c12e2a-bd7f-4cf0-8ef8-92cfa31c92ef", "metadata": {}, "outputs": [], "source": [ "import regex  # Note: this is the third-party regex library, NOT Python's stdlib re module\n", "\n", "# We'll choose a regex that matches a structured JSON string that looks like:\n", "# {\n", "# \"action\": \"Final Answer\",\n", "# \"action_input\": string or dict\n", "# }\n", "pattern = regex.compile(r'\\{\\s*\"action\":\\s*\"Final Answer\",\\s*\"action_input\":\\s*(\\{.*\\}|\"[^\"]*\")\\s*\\}\\nHuman:')" ] }, { "cell_type": "code", "execution_count": 5, "id": "de85b1f8-b405-4291-b6d0-4b2c56e77ad6", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{\"action\": \"Final Answer\",\n", " \"action_input\": \"The capital of Maryland is Baltimore.\"\n", "}\n", "\n" ] } ], "source": [ "from langchain.experimental.llms import RELLM\n", "\n", "model = RELLM(pipeline=hf_model, regex=pattern, max_new_tokens=200)\n", "\n", "generated = model.predict(prompt, stop=[\"Human:\"])\n", "print(generated)" ] }, { "cell_type": "markdown", "id": "32077d74-0605-4138-9a10-0ce36637040d", "metadata": { "tags": [] }, "source": [ "**Voilà! Free of parsing errors.**\n", "\n", "Note that structured decoding constrains only the format of the output, not its correctness: the answer above is wrong, since the capital of Maryland is Annapolis, not Baltimore." ] }, { "cell_type": "code", "execution_count": null, "id": "4bd208a1-779c-4c47-97d9-9115d15d441f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" } }, "nbformat": 4, "nbformat_minor": 5 }