"LLM prompts can be enhanced by injecting specific terms into template sentences. Selecting the right terms is crucial for obtaining high-quality responses. This notebook introduces automated prompt engineering through term injection using Reinforcement Learning with VowpalWabbit.\n",
"\n",
"The rl_chain (reinforcement learning chain) provides a way to automatically determine the best terms to inject without the need for fine-tuning the underlying foundational model.\n",
"\n",
"For illustration, consider the scenario of a meal delivery service. We use LangChain to ask customers, like Tom, about their dietary preferences and recommend suitable meals from our extensive menu. The rl_chain selects a meal based on user preferences, injects it into a prompt template, and forwards the prompt to an LLM. The LLM's response, which is a personalized recommendation, is then returned to the user.\n",
"\n",
"The example laid out below is a toy example to demonstrate the applicability of the concept. Advanced options and explanations are provided at the end."
"The prompt template which will be used to query the LLM needs to be defined.\n",
"It can be anything, but here `{meal}` is being used and is going to be replaced by one of the meals above, the RL chain will try to pick and inject the best meal\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"\n",
"# here I am using the variable meal which will be replaced by one of the meals above\n",
"# and some variables like user, preference, and text_to_personalize which I will provide at chain run time\n",
"\n",
"PROMPT_TEMPLATE = \"\"\"Here is the description of a meal: \"{meal}\".\n",
"\n",
"Embed the meal into the given text: \"{text_to_personalize}\".\n",
"\n",
"Prepend a personalized message including the user's name \"{user}\" \n",
"Next the RL chain's PickBest chain is being initialized. We must provide the llm of choice and the defined prompt. As the name indicates, the chain's goal is to Pick the Best of the meals that will be provided, based on some criteria. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import langchain_experimental.rl_chain as rl_chain\n",
"Once the chain is setup I am going to call it with the meals I want to be selected from, and some context based on which the chain will select a meal."
"Hey Tom! We've got a special treat for you this week - our master chefs have cooked up a delicious One-Pan Tortelonni Bake with peppers and onions, perfect for any Vegetarian who is ok with regular dairy! We know you'll love it!\n"
]
}
],
"source": [
"print(response[\"response\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What is the chain doing\n",
"\n",
"Here's a step-by-step breakdown of the RL chain's operations:\n",
"\n",
"1. Accept the list of meals.\n",
"2. Consider the user and their dietary preferences.\n",
"3. Based on this context, select an appropriate meal.\n",
"4. Automatically evaluate the appropriateness of the meal choice.\n",
"5. Inject the selected meal into the prompt and submit it to the LLM.\n",
"6. Return the LLM's response to the user.\n",
"\n",
"Technically, the chain achieves this by employing a contextual bandit reinforcement learning model, specifically utilizing the [VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) ML library.\n",
"\n",
"Initially, since the RL model is untrained, it might opt for random selections that don't necessarily align with a user's preferences. However, as it gains more exposure to the user's choices and feedback, it should start to make better selections (or quickly learn a good one and just pick that!).\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hey Tom! We know you love vegetarian dishes and that regular dairy is ok, so this week's specialty dish is perfect for you! Our master chefs have created a delicious Chicken Flatbread with red sauce - a unique Italian-Mexican fusion that we know you'll love. Enjoy!\n",
"\n",
"Hey Tom, this week's specialty dish is a delicious Mexican-Greek fusion of Beef Enchiladas with Feta cheese to suit your preference of 'Vegetarian' with 'regular dairy is ok'. Our master chefs believe you will love it!\n",
"\n",
"Hey Tom! Our master chefs have cooked up something special this week - a Mexican-Greek fusion of Beef Enchiladas with Feta cheese - and we know you'll love it as a vegetarian-friendly option with regular dairy included. Enjoy!\n",
"\n",
"Hey Tom! We've got the perfect meal for you this week - our delicious veggie sweet potato quesadillas with vegan cheese, made with the freshest ingredients. Even if you usually opt for regular dairy, we think you'll love this vegetarian dish!\n",
"\n",
"Hey Tom! Our master chefs have outdone themselves this week with a special dish just for you - Chicken Flatbreads with red sauce. It's an Italian-Mexican fusion that's sure to tantalize your taste buds, and it's totally vegetarian friendly with regular dairy is ok. Enjoy!\n",
"It's important to note that while the RL model can make sophisticated selections, it doesn't inherently recognize concepts like \"vegetarian\" or understand that \"beef enchiladas\" aren't vegetarian-friendly. Instead, it leverages the LLM to ground its choices in common sense.\n",
"The way the chain is learning that Tom prefers vegetarian meals is via an AutoSelectionScorer that is built into the chain. The scorer will call the LLM again and ask it to evaluate the selection (`ToSelectFrom`) using the information wrapped in (`BasedOn`).\n",
"If you want to examine the score and other selection metadata you can by examining the metadata object returned by the chain"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hey Tom, this week's meal is something special! Our chefs have prepared a delicious One-Pan Tortelonni Bake with peppers and onions - vegetarian friendly and made with regular dairy, so you can enjoy it without worry. We know you'll love it!\n",
"In a more realistic scenario it is likely that you have a well defined scoring function for what was selected. For example, you might be doing few-shot prompting and want to select prompt examples for a natural language to sql translation task. In that case the scorer could be: did the sql that was generated run in an sql engine? In that case you want to plugin a scoring function. In the example below I will just check if the meal picked was vegetarian or not."
" if \"Vegetarian\" in event.based_on[\"preference\"]:\n",
" if \"Chicken\" in selected_meal or \"Beef\" in selected_meal:\n",
" return 0.0\n",
" else:\n",
" return 1.0\n",
" else:\n",
" if \"Chicken\" in selected_meal or \"Beef\" in selected_meal:\n",
" return 1.0\n",
" else:\n",
" return 0.0\n",
" else:\n",
" raise NotImplementedError(\"I don't know how to score this user\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"chain = rl_chain.PickBest.from_llm(\n",
" llm=llm,\n",
" prompt=PROMPT,\n",
" selection_scorer=CustomSelectionScorer(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'user': ['Tom'], 'preference': ['Vegetarian', 'regular dairy is ok']}\n",
"{'meal': ['Beef Enchiladas with Feta cheese. Mexican-Greek fusion', 'Chicken Flatbreads with red sauce. Italian-Mexican fusion', 'Veggie sweet potato quesadillas with vegan cheese', 'One-Pan Tortelonni bake with peppers and onions']}\n",
"selected meal: Veggie sweet potato quesadillas with vegan cheese\n"
"You can track the chains progress by using the metrics mechanism provided. I am going to expand the users to Tom and Anna, and extend the scoring function. I am going to initialize two chains, one with the default learning policy and one with a built-in random policy (i.e. selects a meal randomly), and plot their scoring progress."
"The RL chain converges to the fact that Anna prefers beef and Tom is vegetarian. The random chain picks at random, and so will send beef to vegetarians half the time."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The final average score for the default policy, calculated over a rolling window, is: 1.0\n",
"The final average score for the random policy, calculated over a rolling window, is: 0.6\n"
"There is a bit of randomness involved in the rl_chain's selection since the chain explores the selection space in order to learn the world as best as it can (see details of default exploration algorithm used [here](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-Exploration-with-SquareCB)), but overall, default chain policy should be doing better than random as it learns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Advanced options\n",
"\n",
"The RL chain is highly configurable in order to be able to adjust to various selection scenarios. If you want to learn more about the ML library that powers it please take a look at tutorials [here](https://vowpalwabbit.org/)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Section | Description | Example / Usage |\n",
"|---------|-------------|-----------------|\n",
"| [**Change Chain Logging Level**](#change-chain-logging-level) | Change the logging level for the RL chain. | `logger.setLevel(logging.INFO)` |\n",
"| [**Featurization**](#featurization) | Adjusts the input to the RL chain. Can set auto-embeddings ON for more complex embeddings. | `chain = rl_chain.PickBest.from_llm(auto_embed=True, [...])` |\n",
"| [**Learned Policy to Learn Asynchronously**](#learned-policy-to-learn-asynchronously) | Score asynchronously if user input is needed for scoring. | `chain.update_with_delayed_score(score=<the score>, chain_response=response)` |\n",
"| [**Store Progress of Learned Policy**](#store-progress-of-learned-policy) | Option to store the progress of the variable injection learned policy. | `chain.save_progress()` |\n",
"| [**Stop Learning of Learned Policy**](#stop-learning-of-learned-policy) | Toggle the RL chain's learned policy updates ON/OFF. | `chain.deactivate_selection_scorer()` |\n",
"| [**Set a Different Policy**](#set-a-different-policy) | Choose between different policies: default, random, or custom. | Custom policy creation at chain creation time. |\n",
"| [**Different Exploration Algorithms and Options for Default Learned Policy**](#different-exploration-algorithms-and-options-for-the-default-learned-policy) | Set different exploration algorithms and hyperparameters for `VwPolicy`. | `vw_cmd = [\"--cb_explore_adf\", \"--quiet\", \"--squarecb\", \"--interactions=::\"]` |\n",
"| [**Learn Policy's Data Logs**](#learned-policys-data-logs) | Store and examine `VwPolicy`'s data logs. | `chain = rl_chain.PickBest.from_llm(vw_logs=<path to log FILE>, [...])` |\n",
"| [**Other Advanced Featurization Options**](#other-advanced-featurization-options) | Specify advanced featurization options for the RL chain. | `age = rl_chain.BasedOn(\"age:32\")` |\n",
"| [**More Info on Auto or Custom SelectionScorer**](#more-info-on-auto-or-custom-selectionscorer) | Dive deeper into how selection scoring is determined. | `selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, scoring_criteria_template_str=scoring_criteria_template)` |\n",
"\n",
"### change chain logging level\n",
"\n",
"```\n",
"import logging\n",
"logger = logging.getLogger(\"rl_chain\")\n",
"logger.setLevel(logging.INFO)\n",
"```\n",
"\n",
"### featurization\n",
"\n",
"#### auto_embed\n",
"\n",
"By default the input to the rl chain (`ToSelectFrom`, `BasedOn`) is not tampered with. This might not be sufficient featurization, so based on how complex the scenario is you can set auto-embeddings to ON\n",
"This will produce more complex embeddings and featurizations of the inputs, likely accelerating RL chain learning, albeit at the cost of increased runtime.\n",
"\n",
"By default, [sbert.net's sentence_transformers's ](https://www.sbert.net/docs/pretrained_models.html#model-overview) `all-mpnet-base-v2` model will be used for these embeddings but you can set a different embeddings model by initializing the chain with it as shown in this example. You could also set an entirely different embeddings encoding object, as long as it has an `encode()` function that returns a list of the encodings.\n",
"Another option is to define what inputs you think should be embedded manually:\n",
"- `auto_embed = False`\n",
"- Can wrap individual variables in `rl_chain.Embed()` or `rl_chain.EmbedAndKeep()` e.g. `user = rl_chain.BasedOn(rl_chain.Embed(\"Tom\"))`\n",
"\n",
"#### custom featurization\n",
"\n",
"Another final option is to define and set a custom featurization/embedder class that returns a valid input for the learned policy.\n",
"\n",
"## learned policy to learn asynchronously\n",
"\n",
"If to score the result you need input from the user (e.g. my application showed Tom the selected meal and Tom clicked on it, but Anna did not), then the scoring can be done asynchronously. The way to do that is:\n",
"\n",
"- set `selection_scorer=None` on the chain creation OR call `chain.deactivate_selection_scorer()`\n",
"- call the chain for a specific input\n",
"- keep the chain's response (`response = chain.run([...])`)\n",
"- once you have determined the score of the response/chain selection call the chain with it: `chain.update_with_delayed_score(score=<the score>, chain_response=response)`\n",
"\n",
"### store progress of learned policy\n",
"\n",
"Since the variable injection learned policy evolves over time, there is the option to store its progress and continue learning. This can be done by calling:\n",
"\n",
"`chain.save_progress()`\n",
"\n",
"which will store the rl chain's learned policy in a file called `latest.vw`. It will also store it in a file with a timestamp. That way, if `save_progress()` is called more than once, multiple checkpoints will be created, but the latest one will always be in `latest.vw`\n",
"\n",
"Next time the chain is loaded, the chain will look for a file called `latest.vw` and if the file exists it will be loaded into the chain and the learning will continue from there.\n",
"\n",
"By default the rl chain model checkpoints will be stored in the current directory but you can specify the save/load location at chain creation time:\n",
"\n",
"`chain = rl_chain.PickBest.from_llm(model_save_dir=<path to dir>, [...])`\n",
"\n",
"### stop learning of learned policy\n",
"\n",
"If you want the rl chain's learned policy to stop updating you can turn it off/on:\n",
"\n",
"`chain.deactivate_selection_scorer()` and `chain.activate_selection_scorer()`\n",
"\n",
"### set a different policy\n",
"\n",
"There are two policies currently available:\n",
"\n",
"- default policy: `VwPolicy` which learns a [Vowpal Wabbit](https://github.com/VowpalWabbit/vowpal_wabbit) [Contextual Bandit](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-algorithms) model\n",
"\n",
"- random policy: `RandomPolicy` which doesn't learn anything and just selects a value randomly. this policy can be used to compare other policies with a random baseline one.\n",
"\n",
"- custom policies: a custom policy could be created and set at chain creation time\n",
"\n",
"### different exploration algorithms and options for the default learned policy\n",
"\n",
"The default `VwPolicy` is initialized with some default arguments. The default exploration algorithm is [SquareCB](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-Exploration-with-SquareCB) but other Contextual Bandit exploration algorithms can be set, and other hyper parameters can be tuned (see [here](https://vowpalwabbit.org/docs/vowpal_wabbit/python/9.6.0/command_line_args.html) for available options).\n",
"The `VwPolicy`'s data files can be stored and examined or used to do [off policy evaluation](https://vowpalwabbit.org/docs/vowpal_wabbit/python/latest/tutorials/off_policy_evaluation.html) for hyper parameter tuning.\n",
"\n",
"The way to do this is to set a log file path to `vw_logs` on chain creation:\n",
"\n",
"`chain = rl_chain.PickBest.from_llm(vw_logs=<path to log FILE>, [...])`\n",
"user = rl_chain.BasedOn([\"Tom Joe\", \"age:32\", \"state of california\"])\n",
"```\n",
"\n",
"there is no dictionary provided since multiple variables can be supplied wrapped in `BasedOn`\n",
"\n",
"Storing the data logs into a file allows the examination of what different inputs do to the data format.\n",
"\n",
"### More info on Auto or Custom SelectionScorer\n",
"\n",
"It is very important to get the selection scorer right since the policy uses it to learn. It determines what is called the reward in reinforcement learning, and more specifically in our Contextual Bandits setting.\n",
"\n",
"The general advice is to keep the score between [0, 1], 0 being the worst selection, 1 being the best selection from the available `ToSelectFrom` variables, based on the `BasedOn` variables, but should be adjusted if the need arises.\n",
"\n",
"In the examples provided above, the AutoSelectionScorer is set mostly to get users started but in real world scenarios it will most likely not be an adequate scorer function.\n",
"\n",
"The example also provided the option to change part of the scoring prompt template that the AutoSelectionScorer used to determine whether a selection was good or not:\n",
"\n",
"```\n",
"scoring_criteria_template = \"Given {preference} rank how good or bad this selection is {meal}\"\n",
"However, if needed, a FULL scoring prompt can also be provided:\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:PickBest] Entering Chain run with input:\n",
"\u001b[0m[inputs]\n",
"\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:PickBest > 2:chain:LLMChain] Entering Chain run with input:\n",
"\u001b[0m[inputs]\n",
"\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:PickBest > 2:chain:LLMChain > 3:llm:OpenAI] Entering LLM run with input:\n",
"\u001b[0m{\n",
" \"prompts\": [\n",
" \"Here is the description of a meal: \\\"Chicken Flatbreads with red sauce. Italian-Mexican fusion\\\".\\n\\nEmbed the meal into the given text: \\\"This is the weeks specialty dish, our master chefs believe you will love it!\\\".\\n\\nPrepend a personalized message including the user's name \\\"Tom\\\" \\n and their preference \\\"['Vegetarian', 'regular dairy is ok']\\\".\\n\\nMake it sound good.\"\n",
" ]\n",
"}\n",
"\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:PickBest > 2:chain:LLMChain > 3:llm:OpenAI] [1.12s] Exiting LLM run with output:\n",
"\u001b[0m{\n",
" \"generations\": [\n",
" [\n",
" {\n",
" \"text\": \"\\nHey Tom, we have something special for you this week! Our master chefs have created a delicious Italian-Mexican fusion Chicken Flatbreads with red sauce just for you. Our chefs have also taken into account your preference of vegetarian options with regular dairy - this one is sure to be a hit!\",\n",
" \"generation_info\": {\n",
" \"finish_reason\": \"stop\",\n",
" \"logprobs\": null\n",
" }\n",
" }\n",
" ]\n",
" ],\n",
" \"llm_output\": {\n",
" \"token_usage\": {\n",
" \"total_tokens\": 154,\n",
" \"completion_tokens\": 61,\n",
" \"prompt_tokens\": 93\n",
" },\n",
" \"model_name\": \"text-davinci-003\"\n",
" },\n",
" \"run\": null\n",
"}\n",
"\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:PickBest > 2:chain:LLMChain] [1.12s] Exiting Chain run with output:\n",
"\u001b[0m{\n",
" \"text\": \"\\nHey Tom, we have something special for you this week! Our master chefs have created a delicious Italian-Mexican fusion Chicken Flatbreads with red sauce just for you. Our chefs have also taken into account your preference of vegetarian options with regular dairy - this one is sure to be a hit!\"\n",
"}\n",
"\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:LLMChain] Entering Chain run with input:\n",
"\u001b[0m[inputs]\n",
"\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:LLMChain > 2:llm:OpenAI] Entering LLM run with input:\n",
" \"Given ['Vegetarian', 'regular dairy is ok'] rank how good or bad this selection is ['Beef Enchiladas with Feta cheese. Mexican-Greek fusion', 'Chicken Flatbreads with red sauce. Italian-Mexican fusion', 'Veggie sweet potato quesadillas with vegan cheese', 'One-Pan Tortelonni bake with peppers and onions']\\n\\nIMPORTANT: you MUST return a single number between -1 and 1, -1 being bad, 1 being good\"\n",
"\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:LLMChain > 2:llm:OpenAI] [274ms] Exiting LLM run with output:\n",
"\u001b[0m{\n",
" \"generations\": [\n",
" [\n",
" {\n",
" \"text\": \"\\n0.625\",\n",
" \"generation_info\": {\n",
" \"finish_reason\": \"stop\",\n",
" \"logprobs\": null\n",
" }\n",
" }\n",
" ]\n",
" ],\n",
" \"llm_output\": {\n",
" \"token_usage\": {\n",
" \"total_tokens\": 112,\n",
" \"completion_tokens\": 4,\n",
" \"prompt_tokens\": 108\n",
" },\n",
" \"model_name\": \"text-davinci-003\"\n",
" },\n",
" \"run\": null\n",
"}\n",
"\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:LLMChain] [275ms] Exiting Chain run with output:\n",
"\u001b[0m{\n",
" \"text\": \"\\n0.625\"\n",
"}\n",
"\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:PickBest] [1.40s] Exiting Chain run with output:\n",
"\u001b[0m[outputs]\n"
]
},
{
"data": {
"text/plain": [
"{'response': 'Hey Tom, we have something special for you this week! Our master chefs have created a delicious Italian-Mexican fusion Chicken Flatbreads with red sauce just for you. Our chefs have also taken into account your preference of vegetarian options with regular dairy - this one is sure to be a hit!',\n",
" 'selection_metadata': <langchain_experimental.rl_chain.pick_best_chain.PickBestEvent at 0x289764220>}"