Add RELLM and JSONFormer experimental LLM decoding (#4185)
[RELLM](https://github.com/r2d4/rellm) is a library that wraps local HuggingFace pipeline models for structured decoding. RELLM works by generating tokens one at a time; at each step, it masks tokens that don't conform to the provided partial regular expression. [JSONFormer](https://github.com/1rgs/jsonformer) is a bit different: it sequentially fills in the JSON keys itself and then decodes each value directly from the model.
parent 54f5523197
commit d85b04be7f
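To make the token-masking idea concrete, here is a minimal illustrative sketch; it is not RELLM's implementation, and the toy vocabulary, greedy token choice, and pattern below are assumptions for demonstration. The key primitive is the `regex` library's partial matching: a candidate token survives a step only if the text so far plus that token is still a prefix of some full match.

```python
import regex

# Toy stand-in for a tokenizer vocabulary (illustrative assumption).
VOCAB = ['{"action": ', '"Harrisburg"', '"Final Answer"', ', "action_input": ', '}']
PATTERN = regex.compile(r'\{"action": "Final Answer", "action_input": "[^"]*"\}')


def constrained_decode(pattern: regex.Pattern, max_steps: int = 10) -> str:
    out = ""
    for _ in range(max_steps):
        if pattern.fullmatch(out):  # output already satisfies the pattern
            break
        # Mask step: keep only tokens that leave the output a viable prefix
        # of a full match; `partial=True` performs exactly that prefix test.
        allowed = [t for t in VOCAB if pattern.fullmatch(out + t, partial=True)]
        if not allowed:
            break
        out += allowed[0]  # a real decoder samples from the model over `allowed`
    return out


print(constrained_decode(PATTERN))
# -> {"action": "Final Answer", "action_input": "Harrisburg"}
```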
docs/modules/models/llms/integrations/jsonformer_experimental.ipynb (Normal file, 280 lines)
@@ -0,0 +1,280 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7864c-93e6-4eb4-a923-b80d2ae4377d",
   "metadata": {},
   "source": [
    "# Structured Decoding with JSONFormer\n",
    "\n",
    "[JSONFormer](https://github.com/1rgs/jsonformer) is a library that wraps local HuggingFace pipeline models for structured decoding of a subset of the JSON Schema.\n",
    "\n",
    "It works by filling in the structure tokens and then sampling the content tokens from the model.\n",
    "\n",
    "**Warning - this module is still experimental**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1617e327-d9a2-4ab6-aa9f-30a3167a3393",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install --upgrade jsonformer > /dev/null"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66bd89f1-8daa-433d-bb8f-5b0b3ae34b00",
   "metadata": {},
   "source": [
    "### HuggingFace Baseline\n",
    "\n",
    "First, let's establish a qualitative baseline by checking the output of the model without structured decoding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d4d616ae-4d11-425f-b06c-c706d0386c68",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.basicConfig(level=logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1bdc7b60-6ffb-4099-9fa6-13efdfc45b04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "from langchain.tools import tool\n",
    "import os\n",
    "import json\n",
    "import requests\n",
    "\n",
    "HF_TOKEN = os.environ.get(\"HUGGINGFACE_API_KEY\")\n",
    "\n",
    "@tool\n",
    "def ask_star_coder(query: str,\n",
    "                   temperature: float = 1.0,\n",
    "                   max_new_tokens: float = 250):\n",
    "    \"\"\"Query the BigCode StarCoder model about coding questions.\"\"\"\n",
    "    url = \"https://api-inference.huggingface.co/models/bigcode/starcoder\"\n",
    "    headers = {\n",
    "        \"Authorization\": f\"Bearer {HF_TOKEN}\",\n",
    "        \"content-type\": \"application/json\"\n",
    "    }\n",
    "    payload = {\n",
    "        \"inputs\": f\"{query}\\n\\nAnswer:\",\n",
    "        \"temperature\": temperature,\n",
    "        \"max_new_tokens\": int(max_new_tokens),\n",
    "    }\n",
    "    response = requests.post(url, headers=headers, data=json.dumps(payload))\n",
    "    response.raise_for_status()\n",
    "    return json.loads(response.content.decode(\"utf-8\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d5522977-51e8-40eb-9403-8ab70b14908e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "prompt = \"\"\"You must respond using JSON format, with a single action and single action input.\n",
    "You may 'ask_star_coder' for help on coding problems.\n",
    "\n",
    "{arg_schema}\n",
    "\n",
    "EXAMPLES\n",
    "----\n",
    "Human: \"So what's all this about a GIL?\"\n",
    "AI Assistant:{{\n",
    " \"action\": \"ask_star_coder\",\n",
" \"action_input\": {{\"query\": \"What is a GIL?\", \"temperature\": 0.0, \"max_new_tokens\": 100}}\"\n",
    "}}\n",
    "Observation: \"The GIL is python's Global Interpreter Lock\"\n",
    "Human: \"Could you please write a calculator program in LISP?\"\n",
    "AI Assistant:{{\n",
    " \"action\": \"ask_star_coder\",\n",
    " \"action_input\": {{\"query\": \"Write a calculator program in LISP\", \"temperature\": 0.0, \"max_new_tokens\": 250}}\n",
    "}}\n",
    "Observation: \"(defun add (x y) (+ x y))\\n(defun sub (x y) (- x y ))\"\n",
"Human: \"What's the difference between an SVM and an LLM?\"\n",
    "AI Assistant:{{\n",
    " \"action\": \"ask_star_coder\",\n",
    " \"action_input\": {{\"query\": \"What's the difference between SGD and an SVM?\", \"temperature\": 1.0, \"max_new_tokens\": 250}}\n",
    "}}\n",
    "Observation: \"SGD stands for stochastic gradient descent, while an SVM is a Support Vector Machine.\"\n",
    "\n",
    "BEGIN! Answer the Human's question as best as you are able.\n",
    "------\n",
    "Human: 'What's the difference between an iterator and an iterable?'\n",
    "AI Assistant:\"\"\".format(arg_schema=ask_star_coder.args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9148e4b8-d370-4c05-a873-c121b65057b5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 'What's the difference between an iterator and an iterable?'\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "from langchain.llms import HuggingFacePipeline\n",
    "\n",
    "hf_model = pipeline(\"text-generation\", model=\"cerebras/Cerebras-GPT-590M\", max_new_tokens=200)\n",
    "\n",
    "original_model = HuggingFacePipeline(pipeline=hf_model)\n",
    "\n",
    "generated = original_model.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n",
    "print(generated)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6e7b9cf-8ce5-4f87-b4bf-100321ad2dd1",
   "metadata": {},
   "source": [
    "***That's not so impressive, is it? It didn't follow the JSON format at all! Let's try with the structured decoder.***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96115154-a90a-46cb-9759-573860fc9b79",
   "metadata": {},
   "source": [
    "## JSONFormer LLM Wrapper\n",
    "\n",
"Let's try that again, now providing a the Action input's JSON Schema to the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "30066ee7-9a92-4ae8-91bf-3262bf3c70c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "decoder_schema = {\n",
    "    \"title\": \"Decoding Schema\",\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\n",
    "        \"action\": {\"type\": \"string\", \"default\": ask_star_coder.name},\n",
    "        \"action_input\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": ask_star_coder.args,\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0f7447fe-22a9-47db-85b9-7adf0f19307d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.experimental.llms import JsonFormer\n",
    "json_former = JsonFormer(json_schema=decoder_schema, pipeline=hf_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d865e049-a5c3-4648-92db-8b912b7474ee",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"action\": \"ask_star_coder\", \"action_input\": {\"query\": \"What's the difference between an iterator and an iter\", \"temperature\": 0.0, \"max_new_tokens\": 50.0}}\n"
     ]
    }
   ],
   "source": [
    "results = json_former.predict(prompt, stop=[\"Observation:\", \"Human:\"])\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32077d74-0605-4138-9a10-0ce36637040d",
   "metadata": {
    "tags": []
   },
   "source": [
    "**Voila! Free of parsing errors.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da63ce31-de79-4462-a1a9-b726b698c5ba",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
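For reference, the `JsonFormer` wrapper in the notebook above delegates to the `jsonformer` package; a minimal sketch of the equivalent direct call, reusing `hf_model`, `decoder_schema`, and `prompt` from the cells above and mirroring `jsonformer_decoder.py` below:

```python
from jsonformer import Jsonformer

# JSONFormer emits the structural tokens itself and samples only the value
# tokens from the model, so the result conforms to the schema by construction.
builder = Jsonformer(
    model=hf_model.model,
    tokenizer=hf_model.tokenizer,
    json_schema=decoder_schema,
    prompt=prompt,
    max_number_tokens=200,
)
result = builder()  # a Python dict shaped by decoder_schema
print(result)
```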
docs/modules/models/llms/integrations/rellm_experimental.ipynb (Normal file, 208 lines)
@@ -0,0 +1,208 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fdd7864c-93e6-4eb4-a923-b80d2ae4377d",
   "metadata": {},
   "source": [
    "# Structured Decoding with RELLM\n",
    "\n",
    "[RELLM](https://github.com/r2d4/rellm) is a library that wraps local HuggingFace pipeline models for structured decoding.\n",
    "\n",
    "It works by generating tokens one at a time. At each step, it masks tokens that don't conform to the provided partial regular expression.\n",
    "\n",
    "**Warning - this module is still experimental**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1617e327-d9a2-4ab6-aa9f-30a3167a3393",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install rellm > /dev/null"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66bd89f1-8daa-433d-bb8f-5b0b3ae34b00",
   "metadata": {},
   "source": [
    "### HuggingFace Baseline\n",
    "\n",
    "First, let's establish a qualitative baseline by checking the output of the model without structured decoding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d4d616ae-4d11-425f-b06c-c706d0386c68",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.basicConfig(level=logging.ERROR)\n",
    "prompt = \"\"\"Human: \"What's the capital of the United States?\"\n",
    "AI Assistant:{\n",
    " \"action\": \"Final Answer\",\n",
    " \"action_input\": \"The capital of the United States is Washington D.C.\"\n",
    "}\n",
    "Human: \"What's the capital of Pennsylvania?\"\n",
    "AI Assistant:{\n",
    " \"action\": \"Final Answer\",\n",
    " \"action_input\": \"The capital of Pennsylvania is Harrisburg.\"\n",
    "}\n",
"Human: \"What 2 + 5?\"\n",
    "AI Assistant:{\n",
    " \"action\": \"Final Answer\",\n",
    " \"action_input\": \"2 + 5 = 7.\"\n",
    "}\n",
    "Human: 'What's the capital of Maryland?'\n",
    "AI Assistant:\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9148e4b8-d370-4c05-a873-c121b65057b5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "generations=[[Generation(text=' \"What\\'s the capital of Maryland?\"\\n', generation_info=None)]] llm_output=None\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "from langchain.llms import HuggingFacePipeline\n",
    "\n",
    "hf_model = pipeline(\"text-generation\", model=\"cerebras/Cerebras-GPT-590M\", max_new_tokens=200)\n",
    "\n",
    "original_model = HuggingFacePipeline(pipeline=hf_model)\n",
    "\n",
    "generated = original_model.generate([prompt], stop=[\"Human:\"])\n",
    "print(generated)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6e7b9cf-8ce5-4f87-b4bf-100321ad2dd1",
   "metadata": {},
   "source": [
    "***That's not so impressive, is it? It didn't answer the question and it didn't follow the JSON format at all! Let's try with the structured decoder.***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96115154-a90a-46cb-9759-573860fc9b79",
   "metadata": {},
   "source": [
    "## RELLM LLM Wrapper\n",
    "\n",
    "Let's try that again, now providing a regex to match the JSON structured format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "65c12e2a-bd7f-4cf0-8ef8-92cfa31c92ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import regex  # Note this is the regex library NOT python's re stdlib module\n",
    "\n",
    "# We'll choose a regex that matches to a structured json string that looks like:\n",
    "# {\n",
    "#  \"action\": \"Final Answer\",\n",
    "#  \"action_input\": string or dict\n",
    "# }\n",
    "pattern = regex.compile(r'\\{\\s*\"action\":\\s*\"Final Answer\",\\s*\"action_input\":\\s*(\\{.*\\}|\"[^\"]*\")\\s*\\}\\nHuman:')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de85b1f8-b405-4291-b6d0-4b2c56e77ad6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"action\": \"Final Answer\",\n",
      " \"action_input\": \"The capital of Maryland is Baltimore.\"\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from langchain.experimental.llms import RELLM\n",
    "\n",
    "model = RELLM(pipeline=hf_model, regex=pattern, max_new_tokens=200)\n",
    "\n",
    "generated = model.predict(prompt, stop=[\"Human:\"])\n",
    "print(generated)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32077d74-0605-4138-9a10-0ce36637040d",
   "metadata": {
    "tags": []
   },
   "source": [
    "**Voila! Free of parsing errors.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bd208a1-779c-4c47-97d9-9115d15d441f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
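Likewise, the `RELLM` wrapper above is a thin shim over `rellm.complete_re`; a sketch of the equivalent direct call, reusing `hf_model`, `pattern`, and `prompt` from the cells above (compare `rellm_decoder.py` below):

```python
import rellm

# complete_re generates token by token, masking candidates that would break
# a partial match of `pattern` - the same call the wrapper makes internally.
text = rellm.complete_re(
    prompt,
    pattern,
    tokenizer=hf_model.tokenizer,
    model=hf_model.model,
    max_new_tokens=200,
)
print(text)
```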
langchain/experimental/llms/__init__.py (Normal file, 6 lines)
@@ -0,0 +1,6 @@
"""Experimental LLM wrappers."""

from langchain.experimental.llms.jsonformer_decoder import JsonFormer
from langchain.experimental.llms.rellm_decoder import RELLM

__all__ = ["RELLM", "JsonFormer"]
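A quick sketch of how these exports fit together: both wrappers subclass `HuggingFacePipeline`, so a single local pipeline can back either decoder (`decoder_schema` and `pattern` are the objects built in the notebooks above):

```python
from transformers import pipeline

from langchain.experimental.llms import JsonFormer, RELLM

# One local HuggingFace pipeline shared by both experimental decoders.
hf_model = pipeline("text-generation", model="cerebras/Cerebras-GPT-590M", max_new_tokens=200)

json_llm = JsonFormer(json_schema=decoder_schema, pipeline=hf_model)  # schema-constrained
regex_llm = RELLM(pipeline=hf_model, regex=pattern, max_new_tokens=200)  # regex-constrained
```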
langchain/experimental/llms/jsonformer_decoder.py (Normal file, 60 lines)
@@ -0,0 +1,60 @@
"""Experimental implementation of jsonformer wrapped LLM."""
from __future__ import annotations

import json
from typing import TYPE_CHECKING, List, Optional, cast

from pydantic import Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_pipeline import HuggingFacePipeline

if TYPE_CHECKING:
    import jsonformer


def import_jsonformer() -> jsonformer:
    """Lazily import jsonformer."""
    try:
        import jsonformer
    except ImportError:
        raise ValueError(
            "Could not import jsonformer python package. "
            "Please install it with `pip install jsonformer`."
        )
    return jsonformer


class JsonFormer(HuggingFacePipeline):
    """Jsonformer wrapped LLM using a HuggingFace Pipeline."""

    json_schema: dict = Field(..., description="The JSON Schema to complete.")
    max_new_tokens: int = Field(
        default=200, description="Maximum number of new tokens to generate."
    )
    debug: bool = Field(default=False, description="Debug mode.")

    @root_validator
    def check_jsonformer_installation(cls, values: dict) -> dict:
        import_jsonformer()
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        jsonformer = import_jsonformer()
        from transformers import Text2TextGenerationPipeline

        pipeline = cast(Text2TextGenerationPipeline, self.pipeline)

        model = jsonformer.Jsonformer(
            model=pipeline.model,
            tokenizer=pipeline.tokenizer,
            json_schema=self.json_schema,
            prompt=prompt,
            max_number_tokens=self.max_new_tokens,
            debug=self.debug,
        )
        text = model()
        return json.dumps(text)
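Since `_call` re-serializes the generated dict with `json.dumps`, callers receive an ordinary string that round-trips cleanly; a small sketch continuing the JSONFormer notebook example (`json_former` and `prompt` as defined there):

```python
import json

results = json_former.predict(prompt, stop=["Observation:", "Human:"])
action = json.loads(results)  # safe: _call produced this string via json.dumps
print(action["action"], "->", action["action_input"]["query"])
```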
langchain/experimental/llms/rellm_decoder.py (Normal file, 67 lines)
@@ -0,0 +1,67 @@
"""Experimental implementation of RELLM wrapped LLM."""
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional, cast

from pydantic import Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.utils import enforce_stop_tokens

if TYPE_CHECKING:
    import rellm
    from regex import Pattern as RegexPattern
else:
    try:
        from regex import Pattern as RegexPattern
    except ImportError:
        pass


def import_rellm() -> rellm:
    """Lazily import rellm."""
    try:
        import rellm
    except ImportError:
        raise ValueError(
            "Could not import rellm python package. "
            "Please install it with `pip install rellm`."
        )
    return rellm


class RELLM(HuggingFacePipeline):
    """RELLM wrapped LLM using a HuggingFace Pipeline."""

    regex: RegexPattern = Field(..., description="The structured format to complete.")
    max_new_tokens: int = Field(
        default=200, description="Maximum number of new tokens to generate."
    )

    @root_validator
    def check_rellm_installation(cls, values: dict) -> dict:
        import_rellm()
        return values

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        rellm = import_rellm()
        from transformers import Text2TextGenerationPipeline

        pipeline = cast(Text2TextGenerationPipeline, self.pipeline)

        text = rellm.complete_re(
            prompt,
            self.regex,
            tokenizer=pipeline.tokenizer,
            model=pipeline.model,
            max_new_tokens=self.max_new_tokens,
        )
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
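For context on the stop-token handling in `_call`: `enforce_stop_tokens` from `langchain.llms.utils` simply truncates at the first stop sequence; roughly (a sketch from memory of that util, not a verbatim copy):

```python
import re
from typing import List


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]


# enforce_stop_tokens('{"action": "Final Answer"}\nHuman: "Hi"', ["Human:"])
# -> '{"action": "Final Answer"}\n'
```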