Merge pull request #14 from VowpalWabbit/notebook_fix

Notebook fix
olgavrou 2023-09-05 12:15:52 -04:00 committed by GitHub
commit 15d33a144d

@ -4,15 +4,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Learned prompt variable injection via rl chain\n",
"# Learned Prompt Variable Injection via RL Chain\n",
"\n",
"The rl_chain (reinforcement learning chain) is used primarily for prompt variable injection: when we want to enhance a prompt with a value but we are not sure which of the available variable values will make the prompt achieve what we want.\n",
"LLM prompts can be enhanced by injecting specific terms into template sentences. Selecting the right terms is crucial for obtaining high-quality responses. This notebook introduces automated prompt engineering through term injection using Reinforcement Learning with VowpalWabbit.\n",
"\n",
"It provides a way to learn a specific prompt engineering policy without fine tuning the underlying foundational model.\n",
"The rl_chain (reinforcement learning chain) provides a way to automatically determine the best terms to inject without the need for fine-tuning the underlying foundational model.\n",
"\n",
"The example layed out below is trivial and a strong llm could make a good variable selection and injection without the intervention of this chain, but it is perfect for showcasing the chain's usage. Advanced options and explanations are provided at the end.\n",
"For illustration, consider the scenario of a meal delivery service. We use LangChain to ask customers, like Tom, about their dietary preferences and recommend suitable meals from our extensive menu. The rl_chain selects a meal based on user preferences, injects it into a prompt template, and forwards the prompt to an LLM. The LLM's response, which is a personalized recommendation, is then returned to the user.\n",
"\n",
"The goal of this example scenario is for the chain to select a meal based on the user declared preferences, and inject the meal into the prompt template. The final prompt will then be sent to the llm of choice and the llm output will be returned to the user."
"The example laid out below is a toy example to demonstrate the applicability of the concept. Advanced options and explanations are provided at the end."
]
},
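For reference, the "meals above" that later cells mention are simply a Python list of candidate strings the chain selects from. A minimal sketch follows; the meal descriptions are illustrative, not the notebook's exact menu.

```python
# Candidate meals the chain can inject into the prompt; descriptions are illustrative.
meals = [
    "Beef Enchiladas with Feta cheese. Mexican-Greek fusion",
    "Chicken Flatbreads with red sauce. Italian-Mexican fusion",
    "Veggie sweet potato quesadillas with vegan cheese",
    "One-Pan Tortellini Bake with peppers and onions",
]
```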
{
@ -35,35 +35,22 @@
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"\\n\\nYes, I'm ready.\""
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# pick and configure the LLM of your choice\n",
"\n",
"from langchain.llms import OpenAI\n",
"llm = OpenAI(engine=\"text-davinci-003\")\n",
"\n",
"llm.predict(\"are you ready?\")"
"llm = OpenAI(engine=\"text-davinci-003\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Intialize the rl chain with provided defaults\n",
"##### Intialize the RL chain with provided defaults\n",
"\n",
"The prompt template which will be used to query the LLM needs to be defined.\n",
"It can be anything, but here `{meal}` is being used and is going to be replaced by one of the meals above, the rl chain will try to pick and inject the best meal\n"
"It can be anything, but here `{meal}` is being used and is going to be replaced by one of the meals above, the RL chain will try to pick and inject the best meal\n"
]
},
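A minimal sketch of such a template follows. The wording is illustrative; the only requirement is that it contains the `{meal}` variable plus whichever other variables (here `user`, `preference`, and `text_to_personalize`) will be supplied at run time.

```python
from langchain.prompts import PromptTemplate

# {meal} is filled in by the RL chain; the remaining variables are passed when the chain runs.
PROMPT_TEMPLATE = """Here is the description of a meal: "{meal}".

Embed the meal into the given text: "{text_to_personalize}".

Prepend a personalized message including the user's name "{user}"
and their preference "{preference}".

Make it sound good.
"""

PROMPT = PromptTemplate(
    input_variables=["meal", "text_to_personalize", "user", "preference"],
    template=PROMPT_TEMPLATE,
)
```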
{
@ -97,7 +84,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Next the rl chain's PickBest chain is being initialized. We must provide the llm of choice and the defined prompt. As the name indicates, the chain's goal is to Pick the Best of the meals that will be provided, based on some criteria. "
"Next the RL chain's PickBest chain is being initialized. We must provide the llm of choice and the defined prompt. As the name indicates, the chain's goal is to Pick the Best of the meals that will be provided, based on some criteria. "
]
},
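A minimal sketch of the initialization, assuming the chain module is importable as `rl_chain` (the exact import path depends on your LangChain version; in later releases it is `langchain_experimental.rl_chain`) and that `PROMPT` is the template defined above:

```python
import langchain_experimental.rl_chain as rl_chain  # assumption: adjust the import to your install

# PickBest needs the LLM it should query and the prompt whose {meal} slot it will fill.
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
```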
{
@ -156,18 +143,18 @@
"source": [
"## What is the chain doing\n",
"\n",
"What is happening behind the scenes here is that the rl chain will\n",
"Here's a step-by-step breakdown of the RL chain's operations:\n",
"\n",
"- take the meals\n",
"- take the user and their preference\n",
"- based on the user and their preference (context) it will select a meal\n",
"- it will auto-evaluate if that meal selection was good or bad\n",
"- it will finally inject the meal into the prompt and query the llm\n",
"- the user will get the llm response back\n",
"1. Accept the list of meals.\n",
"2. Consider the user and their dietary preferences.\n",
"3. Based on this context, select an appropriate meal.\n",
"4. Automatically evaluate the appropriateness of the meal choice.\n",
"5. Inject the selected meal into the prompt and submit it to the LLM.\n",
"6. Return the LLM's response to the user.\n",
"\n",
"Now, the way the chain is doing this is that it is learning a contextual bandit rl model that is trained to make good selections (specifially the [VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) ML library is being used).\n",
"Technically, the chain achieves this by employing a contextual bandit reinforcement learning model, specifically utilizing the [VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) ML library.\n",
"\n",
"Since this rl model will be untrained when we first start, it might make a random selection that doesn't fit the user and their preferences. But if we give it time to learn the user and their preferences, it should start to make better selections (or quickly learn a good one and just pick that!)."
"Initially, since the RL model is untrained, it might opt for random selections that don't necessarily align with a user's preferences. However, as it gains more exposure to the user's choices and feedback, it should start to make better selections (or quickly learn a good one and just pick that!).\n"
]
},
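In code, a single pass through these steps looks like the sketch below, mirroring the loop later in the notebook: `ToSelectFrom` wraps the candidate values (the actions) and `BasedOn` wraps the context the selection is conditioned on. Reading the `response` key of the result is an assumption about the chain's output format.

```python
response = chain.run(
    meal=rl_chain.ToSelectFrom(meals),    # actions: the values to select from
    user=rl_chain.BasedOn("Tom"),         # context: who the recommendation is for
    preference=rl_chain.BasedOn(["Vegetarian", "regular dairy is ok"]),  # context: their preferences
    text_to_personalize="This is the week's specialty dish, our master chefs believe you will love it!",
)
print(response["response"])  # the LLM's personalized recommendation (key name assumed)
```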
{
@ -213,6 +200,8 @@
"source": [
"## How is the chain learning\n",
"\n",
"It's important to note that while the RL model can make sophisticated selections, it doesn't inherently recognize concepts like \"vegetarian\" or understand that \"beef enchiladas\" aren't vegetarian-friendly. Instead, it leverages the LLM to ground its choices in common sense.\n",
"\n",
"The way the chain is learning that Tom prefers veggetarian meals is via an AutoSelectionScorer that is built into the chain. The scorer will call the LLM again and ask it to evaluate the selection (`ToSelectFrom`) using the information wrapped in (`BasedOn`).\n",
"\n",
"You can set `langchain.debug=True` if you want to see the details of the auto-scorer, but you can also define the scoring prompt yourself."
@ -275,7 +264,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
@ -309,7 +298,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
@ -355,7 +344,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
@ -386,7 +375,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
@ -410,42 +399,48 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"for i in range(40):\n",
"for _ in range(20):\n",
" try:\n",
" if i % 2:\n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" else:\n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" \n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" except Exception as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The RL chain converges to the fact that Anna prefers beef and Tom is vegetarian. The random chain picks at random, and so will send beef to vegetarians half the time."
]
},
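One way to see this numerically is to compare the running scores of the two chains. This assumes both chains were created with metric tracking enabled (for example via a `metrics_step` argument to `from_llm`); `chain.metrics.to_pandas()` follows the API as it appears in later releases, so treat it as an assumption for your version.

```python
from matplotlib import pyplot as plt

# Plot the auto-scorer's score over time for each chain; the learning policy should trend higher.
chain.metrics.to_pandas()["score"].plot(label="learning policy")
random_chain.metrics.to_pandas()["score"].plot(label="random policy")
plt.legend()
plt.show()
```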
{
"cell_type": "code",
"execution_count": 25,
@ -495,7 +490,7 @@
"source": [
"## Advanced options\n",
"\n",
"The rl chain is highly configurable in order to be able to adjust to various selection scenarios. If you want to learn more about the ML library that powers it please take a look at tutorials [here](https://vowpalwabbit.org/)\n"
"The RL chain is highly configurable in order to be able to adjust to various selection scenarios. If you want to learn more about the ML library that powers it please take a look at tutorials [here](https://vowpalwabbit.org/)\n"
]
},
{