notebook fix

pull/10242/head
olgavrou 1 year ago
parent 235dacc74a
commit 7a4387c60d

@@ -4,15 +4,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Learned prompt variable injection via rl chain\n",
"# Learned Prompt Variable Injection via RL Chain\n",
"\n",
"The rl_chain (reinforcement learning chain) is used primarily for prompt variable injection: when we want to enhance a prompt with a value but we are not sure which of the available variable values will make the prompt achieve what we want.\n",
"LLM prompts can be enhanced by injecting specific terms into template sentences. Selecting the right terms is crucial for obtaining high-quality responses. This notebook introduces automated prompt engineering through term injection using Reinforcement Learning with VowpalWabbit.\n",
"\n",
"It provides a way to learn a specific prompt engineering policy without fine tuning the underlying foundational model.\n",
"The rl_chain (reinforcement learning chain) provides a way to automatically determine the best terms to inject without the need for fine-tuning the underlying foundational model.\n",
"\n",
"The example layed out below is trivial and a strong llm could make a good variable selection and injection without the intervention of this chain, but it is perfect for showcasing the chain's usage. Advanced options and explanations are provided at the end.\n",
"For illustration, consider the scenario of a meal delivery service. We use LangChain to ask customers, like Tom, about their dietary preferences and recommend suitable meals from our extensive menu. The rl_chain selects a meal based on user preferences, injects it into a prompt template, and forwards the prompt to an LLM. The LLM's response, which is a personalized recommendation, is then returned to the user.\n",
"\n",
"The goal of this example scenario is for the chain to select a meal based on the user declared preferences, and inject the meal into the prompt template. The final prompt will then be sent to the llm of choice and the llm output will be returned to the user."
"The example laid out below is a toy example to demonstrate the applicability of the concept. Advanced options and explanations are provided at the end."
]
},
{
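For orientation, the setup described above can look roughly like the sketch below. This is not part of the diff: the import path, the `PickBest.from_llm` constructor, and the prompt wording are assumptions for illustration; the notebook's own setup cells (elided from this diff) are authoritative.

```python
# Hypothetical sketch of the chain setup described above (not part of this diff).
# Assumption: the rl_chain module is available, e.g. via langchain_experimental.
import langchain_experimental.rl_chain as rl_chain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

llm = OpenAI(engine="text-davinci-003")

# The prompt template has a {meal} slot that the chain fills with its selection.
PROMPT = PromptTemplate(
    input_variables=["meal", "user", "preference", "text_to_personalize"],
    template=(
        "Here is the description of a meal: {meal}.\n\n"
        'Embed the meal into the given text: "{text_to_personalize}".\n\n'
        "Prepend a personalized message including the user's name {user} "
        "and their preference {preference}."
    ),
)

# The chain learns which candidate value to inject into the {meal} slot.
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
```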
@@ -35,25 +35,12 @@
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"\\n\\nYes, I'm ready.\""
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# pick and configure the LLM of your choice\n",
"\n",
"from langchain.llms import OpenAI\n",
"llm = OpenAI(engine=\"text-davinci-003\")\n",
"\n",
"llm.predict(\"are you ready?\")"
"llm = OpenAI(engine=\"text-davinci-003\")\n"
]
},
{
@@ -156,18 +143,18 @@
"source": [
"## What is the chain doing\n",
"\n",
"What is happening behind the scenes here is that the rl chain will\n",
"Here's a step-by-step breakdown of the RL chain's operations:\n",
"\n",
"- take the meals\n",
"- take the user and their preference\n",
"- based on the user and their preference (context) it will select a meal\n",
"- it will auto-evaluate if that meal selection was good or bad\n",
"- it will finally inject the meal into the prompt and query the llm\n",
"- the user will get the llm response back\n",
"1. Accept the list of meals.\n",
"2. Consider the user and their dietary preferences.\n",
"3. Based on this context, select an appropriate meal.\n",
"4. Automatically evaluate the appropriateness of the meal choice.\n",
"5. Inject the selected meal into the prompt and submit it to the LLM.\n",
"6. Return the LLM's response to the user.\n",
"\n",
"Now, the way the chain is doing this is that it is learning a contextual bandit rl model that is trained to make good selections (specifially the [VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) ML library is being used).\n",
"Technically, the chain achieves this by employing a contextual bandit reinforcement learning model, specifically utilizing the [VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) ML library.\n",
"\n",
"Since this rl model will be untrained when we first start, it might make a random selection that doesn't fit the user and their preferences. But if we give it time to learn the user and their preferences, it should start to make better selections (or quickly learn a good one and just pick that!)."
"Initially, since the RL model is untrained, it might opt for random selections that don't necessarily align with a user's preferences. However, as it gains more exposure to the user's choices and feedback, it should start to make better selections (or quickly learn a good one and just pick that!).\n"
]
},
{
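To make the step-by-step breakdown concrete, a single selection step can look like the sketch below. It reuses the `chain`, the `meals` list, and the `rl_chain.ToSelectFrom` / `rl_chain.BasedOn` wrappers that appear later in this notebook; the shape of the returned value is an assumption.

```python
# Sketch of one selection step (mirrors the calls in the training loop below).
# ToSelectFrom marks the candidates to pick from; BasedOn marks the context
# (user and preferences) that drives the contextual-bandit selection.
# Assumes `chain`, `meals`, and `rl_chain` are defined earlier in the notebook.
response = chain.run(
    meal=rl_chain.ToSelectFrom(meals),
    user=rl_chain.BasedOn("Tom"),
    preference=rl_chain.BasedOn(["Vegetarian", "regular dairy is ok"]),
    text_to_personalize="This is the week's specialty dish, our master chefs believe you will love it!",
)
# Assumption: the chain returns a dict whose "response" entry is the LLM output.
print(response["response"])
```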
@@ -213,6 +200,8 @@
"source": [
"## How is the chain learning\n",
"\n",
"It's important to note that while the RL model can make sophisticated selections, it doesn't inherently recognize concepts like \"vegetarian\" or understand that \"beef enchiladas\" aren't vegetarian-friendly. Instead, it leverages the LLM to ground its choices in common sense.\n",
"\n",
"The way the chain is learning that Tom prefers veggetarian meals is via an AutoSelectionScorer that is built into the chain. The scorer will call the LLM again and ask it to evaluate the selection (`ToSelectFrom`) using the information wrapped in (`BasedOn`).\n",
"\n",
"You can set `langchain.debug=True` if you want to see the details of the auto-scorer, but you can also define the scoring prompt yourself."
@@ -275,7 +264,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
@@ -309,7 +298,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
@@ -355,7 +344,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
@@ -386,7 +375,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
@@ -410,42 +399,48 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"for i in range(40):\n",
"for _ in range(20):\n",
" try:\n",
" if i % 2:\n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" else:\n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" \n",
" chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" user = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" except Exception as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The RL chain converges to the fact that Anna prefers beef and Tom is vegetarian. The random chain picks at random, and so will send beef to vegetarians half the time."
]
},
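One way to check this claim is to compare the running scores of the two policies. The sketch below assumes both chains were created with metrics tracking enabled (e.g. a `metrics_step` argument) and that the tracker exposes a `to_pandas()` helper; treat it as illustrative rather than the notebook's exact plotting cell.

```python
# Hypothetical sketch: plot the running scores of the learned vs. random policy.
# Assumes both chains track metrics (e.g. created with metrics_step=5) and that
# chain.metrics.to_pandas() returns a DataFrame with a "score" column.
import matplotlib.pyplot as plt

chain.metrics.to_pandas()["score"].plot(label="learned policy")
random_chain.metrics.to_pandas()["score"].plot(label="random policy")
plt.legend()
plt.title("Average score of meal selections over time")
plt.show()
```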
{
"cell_type": "code",
"execution_count": 25,
