mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
run notebook and change location
This commit is contained in:
parent
62cf108700
commit
f1d144cd6c
File diff suppressed because one or more lines are too long
@ -1,646 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# Learned prompt variable injection via rl chain\n",
|
|
||||||
"\n",
|
|
||||||
"The rl_chain is used primarily for prompt variable injection: when we want to enhance a prompt with a value but we are not sure which of the available variable values will make the prompt achieve what we want.\n",
|
|
||||||
"\n",
|
|
||||||
"It provides a way to learn a specific prompt engineering policy without fine tuning the underlying foundational model.\n",
|
|
||||||
"\n",
|
|
||||||
"The example layed out below is trivial and a strong llm could make a good variable selection and injection without the intervention of this chain, but it is perfect for showcasing the chain's usage. Advanced options and explanations are provided at the end.\n",
|
|
||||||
"\n",
|
|
||||||
"The goal of the below scenario is for the chain to select a meal based on the user declared preferences, and inject the meal into the prompt template. The final prompt will then be sent to the llm of choice and the llm output will be returned to the user."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# four meals defined, some vegetarian some not\n",
|
|
||||||
"\n",
|
|
||||||
"meals = [\n",
|
|
||||||
" \"Beef Enchiladas with Feta cheese. Mexican-Greek fusion\",\n",
|
|
||||||
" \"Chicken Flatbreads with red sauce. Italian-Mexican fusion\",\n",
|
|
||||||
" \"Veggie sweet potato quesadillas with vegan cheese\",\n",
|
|
||||||
" \"One-Pan Tortelonni bake with peppers and onions\",\n",
|
|
||||||
"]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# pick and configure the LLM of your choice\n",
|
|
||||||
"\n",
|
|
||||||
"from langchain.llms import OpenAI\n",
|
|
||||||
"llm = OpenAI(engine=\"text-davinci-003\")\n",
|
|
||||||
"\n",
|
|
||||||
"llm.predict(\"are you ready?\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"##### Intialize the rl chain with provided defaults\n",
|
|
||||||
"\n",
|
|
||||||
"The prompt template which will be used to query the LLM needs to be defined.\n",
|
|
||||||
"It can be anything, but here `{meal}` is being used and is going to be replaced by one of the meals above, the rl chain will try to pick and inject the best meal\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from langchain.prompts import PromptTemplate\n",
|
|
||||||
"\n",
|
|
||||||
"# here I am using the variable meal which will be replaced by one of the meals above\n",
|
|
||||||
"# and some variables like user, preference, and text_to_personalize which I will provide at chain run time\n",
|
|
||||||
"\n",
|
|
||||||
"PROMPT_TEMPLATE = \"\"\"Here is the description of a meal: \"{meal}\".\n",
|
|
||||||
"\n",
|
|
||||||
"Embed the meal into the given text: \"{text_to_personalize}\".\n",
|
|
||||||
"\n",
|
|
||||||
"Prepend a personalized message including the user's name {user} and their preference {preference}.\n",
|
|
||||||
"\n",
|
|
||||||
"Make it sound good.\n",
|
|
||||||
"\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"PROMPT = PromptTemplate(\n",
|
|
||||||
" input_variables=[\"meal\", \"text_to_personalize\", \"user\", \"preference\"], template=PROMPT_TEMPLATE\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Next the rl chain's PickBest chain is being initialized. We must provide the llm of choice and the defined prompt. As the name indicates, the chain's goal is to Pick the Best of the meals that will be provided, based on some criteria. "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import langchain.chains.rl_chain as rl_chain\n",
|
|
||||||
"\n",
|
|
||||||
"chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Once the chain is setup I am going to call it with the meals I want to be selected from, and some context based on which the chain will select a meal."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"response = chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(response[\"response\"])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## What is the chain doing\n",
|
|
||||||
"\n",
|
|
||||||
"What is happening behind the scenes here is that the rl chain will\n",
|
|
||||||
"\n",
|
|
||||||
"- take the meals\n",
|
|
||||||
"- take the user and their preference\n",
|
|
||||||
"- based on the user and their preference (context) it will select a meal\n",
|
|
||||||
"- it will auto-evaluate if that meal selection was good or bad\n",
|
|
||||||
"- it will finally inject the meal into the prompt and query the llm\n",
|
|
||||||
"- the user will get the llm response back\n",
|
|
||||||
"\n",
|
|
||||||
"Now, the way the chain is doing this is that it is learning a contextual bandit rl model that is trained to make good selections (specifially the [VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) ML library is being used).\n",
|
|
||||||
"\n",
|
|
||||||
"Since this rl model will be untrained when we first start, it might make a random selection that doesn't fit the user and their preferences. But if we give it time to learn the user and their preferences, it should start to make better selections (or quickly learn a good one and just pick that!)."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"for _ in range(5):\n",
|
|
||||||
" try:\n",
|
|
||||||
" response = chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
" )\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" print(e)\n",
|
|
||||||
" print(response[\"response\"])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## How is the chain learning\n",
|
|
||||||
"\n",
|
|
||||||
"The way the chain is learning that Tom prefers veggetarian meals is via an AutoSelectionScorer that is built into the chain. The scorer will call the LLM again and ask it to evaluate the selection (`ToSelectFrom`) using the information wrapped in (`BasedOn`).\n",
|
|
||||||
"\n",
|
|
||||||
"You can set `langchain.debug=True` if you want to see the details of the auto-scorer, but you can also define the scoring prompt yourself."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"scoring_criteria_template = \"Given {preference} rank how good or bad this selection is {meal}\"\n",
|
|
||||||
"\n",
|
|
||||||
"chain = rl_chain.PickBest.from_llm(\n",
|
|
||||||
" llm=llm,\n",
|
|
||||||
" prompt=PROMPT,\n",
|
|
||||||
" selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, scoring_criteria_template_str=scoring_criteria_template),\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"If you want to examine the score and other selection metadata you can by examining the metadata object returned by the chain"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"response = chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
")\n",
|
|
||||||
"print(response[\"response\"])\n",
|
|
||||||
"selection_metadata = response[\"selection_metadata\"]\n",
|
|
||||||
"print(f\"selected index: {selection_metadata.selected.index}, score: {selection_metadata.selected.score}\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"In a more realistic scenario it is likely that you have a well defined scoring function for what was selected. For example, you might be doing few-shot prompting and want to select prompt examples for a natural language to sql translation task. In that case the scorer could be: did the sql that was generated run in an sql engine? In that case you want to plugin a scoring function. In the example below I will just check if the meal picked was vegetarian or not."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"class CustomSelectionScorer(rl_chain.SelectionScorer):\n",
|
|
||||||
" def score_response(\n",
|
|
||||||
" self, inputs, llm_response: str, event: rl_chain.PickBestEvent) -> float:\n",
|
|
||||||
"\n",
|
|
||||||
" print(event.based_on)\n",
|
|
||||||
" print(event.to_select_from)\n",
|
|
||||||
"\n",
|
|
||||||
" # you can build a complex scoring function here\n",
|
|
||||||
" # it is prefereable that the score ranges between 0 and 1 but it is not enforced\n",
|
|
||||||
"\n",
|
|
||||||
" selected_meal = event.to_select_from[\"meal\"][event.selected.index]\n",
|
|
||||||
" print(f\"selected meal: {selected_meal}\")\n",
|
|
||||||
"\n",
|
|
||||||
" if \"Tom\" in event.based_on[\"user\"]:\n",
|
|
||||||
" if \"Vegetarian\" in event.based_on[\"preference\"]:\n",
|
|
||||||
" if \"Chicken\" in selected_meal or \"Beef\" in selected_meal:\n",
|
|
||||||
" return 0.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" return 1.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" if \"Chicken\" in selected_meal or \"Beef\" in selected_meal:\n",
|
|
||||||
" return 1.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" return 0.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" raise NotImplementedError(\"I don't know how to score this user\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"chain = rl_chain.PickBest.from_llm(\n",
|
|
||||||
" llm=llm,\n",
|
|
||||||
" prompt=PROMPT,\n",
|
|
||||||
" selection_scorer=CustomSelectionScorer(),\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"response = chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## How can I track the chains progress\n",
|
|
||||||
"\n",
|
|
||||||
"You can track the chains progress by using the metrics mechanism provided. I am going to expand the users to Tom and Anna, and extend the scoring function. I am going to initialize two chains, one with the default learning policy and one with a built-in random policy (i.e. selects a meal randomly), and plot their scoring progress."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"class CustomSelectionScorer(rl_chain.SelectionScorer):\n",
|
|
||||||
" def score_preference(self, preference, selected_meal):\n",
|
|
||||||
" if \"Vegetarian\" in preference:\n",
|
|
||||||
" if \"Chicken\" in selected_meal or \"Beef\" in selected_meal:\n",
|
|
||||||
" return 0.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" return 1.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" if \"Chicken\" in selected_meal or \"Beef\" in selected_meal:\n",
|
|
||||||
" return 1.0\n",
|
|
||||||
" else:\n",
|
|
||||||
" return 0.0\n",
|
|
||||||
" def score_response(\n",
|
|
||||||
" self, inputs, llm_response: str, event: rl_chain.PickBestEvent) -> float:\n",
|
|
||||||
"\n",
|
|
||||||
" selected_meal = event.to_select_from[\"meal\"][event.selected.index]\n",
|
|
||||||
"\n",
|
|
||||||
" if \"Tom\" in event.based_on[\"user\"]:\n",
|
|
||||||
" return self.score_preference(event.based_on[\"preference\"], selected_meal)\n",
|
|
||||||
" elif \"Anna\" in event.based_on[\"user\"]:\n",
|
|
||||||
" return self.score_preference(event.based_on[\"preference\"], selected_meal)\n",
|
|
||||||
" else:\n",
|
|
||||||
" raise NotImplementedError(\"I don't know how to score this user\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"chain = rl_chain.PickBest.from_llm(\n",
|
|
||||||
" llm=llm,\n",
|
|
||||||
" prompt=PROMPT,\n",
|
|
||||||
" selection_scorer=CustomSelectionScorer(),\n",
|
|
||||||
" metrics_step=5,\n",
|
|
||||||
" metrics_window_size=5, # rolling window average\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"random_chain = rl_chain.PickBest.from_llm(\n",
|
|
||||||
" llm=llm,\n",
|
|
||||||
" prompt=PROMPT,\n",
|
|
||||||
" selection_scorer=CustomSelectionScorer(),\n",
|
|
||||||
" metrics_step=5,\n",
|
|
||||||
" metrics_window_size=5, # rolling window average\n",
|
|
||||||
" policy=rl_chain.PickBestRandomPolicy # set the random policy instead of default\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"for i in range(40):\n",
|
|
||||||
" try:\n",
|
|
||||||
" if i % 2:\n",
|
|
||||||
" chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
" )\n",
|
|
||||||
" random_chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
" )\n",
|
|
||||||
" else:\n",
|
|
||||||
" chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Anna\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
" )\n",
|
|
||||||
" random_chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Anna\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
" )\n",
|
|
||||||
" except Exception as e:\n",
|
|
||||||
" print(e)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# note matplotlib is not a dependency of langchain so you need to install it\n",
|
|
||||||
"\n",
|
|
||||||
"from matplotlib import pyplot as plt\n",
|
|
||||||
"chain.metrics.to_pandas()['score'].plot(label=\"default learning policy\")\n",
|
|
||||||
"random_chain.metrics.to_pandas()['score'].plot(label=\"random selection policy\")\n",
|
|
||||||
"plt.legend()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"There is a bit of randomness involved in the rl_chain's selection since the chain explores the selection space in order to learn the world as best as it can (see details of default exploration algorithm used [here](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-Exploration-with-SquareCB)), but overall, default chain policy should be doing better than random as it learns"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Advanced options\n",
|
|
||||||
"\n",
|
|
||||||
"The rl chain is highly configurable in order to be able to adjust to various selection scenarios. If you want to learn more about the ML library that powers it please take a look at tutorials [here](https://vowpalwabbit.org/)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"| Section | Description | Example / Usage |\n",
|
|
||||||
"|---------|-------------|-----------------|\n",
|
|
||||||
"| [**Set Chain Logging Level**](#set-chain-logging-level) | Set up the logging level for the RL chain. | `logger.setLevel(logging.INFO)` |\n",
|
|
||||||
"| [**Featurization**](#featurization) | Adjusts the input to the RL chain. Can set auto-embeddings ON for more complex embeddings. | `chain = rl_chain.PickBest.from_llm(auto_embed=True, [...])` |\n",
|
|
||||||
"| [**Learned Policy to Learn Asynchronously**](#learned-policy-to-learn-asynchronously) | Score asynchronously if user input is needed for scoring. | `chain.update_with_delayed_score(score=<the score>, chain_response=response)` |\n",
|
|
||||||
"| [**Store Progress of Learned Policy**](#store-progress-of-learned-policy) | Option to store the progress of the variable injection learned policy. | `chain.save_progress()` |\n",
|
|
||||||
"| [**Stop Learning of Learned Policy**](#stop-learning-of-learned-policy) | Toggle the RL chain's learned policy updates ON/OFF. | `chain.deactivate_selection_scorer()` |\n",
|
|
||||||
"| [**Set a Different Policy**](#set-a-different-policy) | Choose between different policies: default, random, or custom. | Custom policy creation at chain creation time. |\n",
|
|
||||||
"| [**Different Exploration Algorithms for Default Learned Policy**](#different-exploration-algorithms-for-the-default-learned-policy) | Set different exploration algorithms and hyperparameters for `VwPolicy`. | `vw_cmd = [\"--cb_explore_adf\", \"--quiet\", \"--squarecb\", \"--interactions=::\"]` |\n",
|
|
||||||
"| [**Learn Policy's Data Logs**](#learn-policys-data-logs) | Store and examine `VwPolicy`'s data logs. | `chain = rl_chain.PickBest.from_llm(vw_logs=<path to log FILE>, [...])` |\n",
|
|
||||||
"| [**Other Advanced Featurization Options**](#other-advanced-featurization-options) | Specify advanced featurization options for the RL chain. | `age = rl_chain.BasedOn(\"age:32\")` |\n",
|
|
||||||
"| [**More Info on Auto or Custom SelectionScorer**](#more-info-on-auto-or-custom-selectionscorer) | Dive deeper into how selection scoring is determined. | `selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, scoring_criteria_template_str=scoring_criteria_template)` |\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"### set chain logging level\n",
|
|
||||||
"\n",
|
|
||||||
"```\n",
|
|
||||||
"import logging\n",
|
|
||||||
"logger = logging.getLogger(\"rl_chain\")\n",
|
|
||||||
"logger.setLevel(logging.INFO)\n",
|
|
||||||
"```\n",
|
|
||||||
"\n",
|
|
||||||
"### featurization\n",
|
|
||||||
"\n",
|
|
||||||
"By default the input to the rl chain (`ToSelectFrom`, `BasedOn`) is not tampered with. This might not be sufficient featurization, so based on how complex the scenario is you can set auto-embeddings to ON\n",
|
|
||||||
"\n",
|
|
||||||
"`chain = rl_chain.PickBest.from_llm(auto_embed=True, [...])`\n",
|
|
||||||
"\n",
|
|
||||||
"This will produce more complex embeddings and featurizations of the inputs, likely accelerating RL chain learning, albeit at the cost of increased runtime..\n",
|
|
||||||
"\n",
|
|
||||||
"By default, [sbert.net's sentence_transformers's ](https://www.sbert.net/docs/pretrained_models.html#model-overview) `all-mpnet-base-v2` model will be used for these embeddings but you can set a different model by initializing the chain with it, or set an entirely different encoding object as long as it has an `encode` function that returns a list of the encodings:\n",
|
|
||||||
"\n",
|
|
||||||
"```\n",
|
|
||||||
"from sentence_transformers import SentenceTransformer\n",
|
|
||||||
"\n",
|
|
||||||
"chain = rl_chain.PickBest.from_llm(\n",
|
|
||||||
" [...]\n",
|
|
||||||
" feature_embedder=rl_chain.PickBestFeatureEmbedder(\n",
|
|
||||||
" auto_embed=True,\n",
|
|
||||||
" model=SentenceTransformer(\"all-mpnet-base-v2\")\n",
|
|
||||||
" )\n",
|
|
||||||
")\n",
|
|
||||||
"```\n",
|
|
||||||
"\n",
|
|
||||||
"Another option is to define what inputs you think should be embedded manually:\n",
|
|
||||||
"- `auto_embed = False`\n",
|
|
||||||
"- Can wrap individual variables in `rl_chain.Embed()` or `rl_chain.EmbedAndKeep()` e.g. `user = rl_chain.BasedOn(rl_chain.Embed(\"Tom\"))`\n",
|
|
||||||
"\n",
|
|
||||||
"Final option is to define and set your own feature embedder that returns a valid input for the learned policy.\n",
|
|
||||||
"\n",
|
|
||||||
"## learned policy to learn asynchronously\n",
|
|
||||||
"\n",
|
|
||||||
"If to score the result you need input from the user (e.g. my application showed Tom the selected meal and Tom clicked on it, but Anna did not), then the scoring can be done asynchronously. The way to do that is:\n",
|
|
||||||
"\n",
|
|
||||||
"- set `selection_scorer=None` on the chain creation OR call `chain.deactivate_selection_scorer()`\n",
|
|
||||||
"- call the chain for a specific input\n",
|
|
||||||
"- keep the chain's response (`response = chain.run([...])`)\n",
|
|
||||||
"- once you have determined the score of the response/chain selection call the chain with it: `chain.update_with_delayed_score(score=<the score>, chain_response=response)`\n",
|
|
||||||
"\n",
|
|
||||||
"### store progress of learned policy\n",
|
|
||||||
"\n",
|
|
||||||
"Since the variable injection learned policy evolves over time, there is the option to store its progress and continue learning. This can be done by calling:\n",
|
|
||||||
"\n",
|
|
||||||
"`chain.save_progress()`\n",
|
|
||||||
"\n",
|
|
||||||
"which will store the rl chain's learned policy in a file called `latest.vw`. It will also store it in a file with a timestamp. That way, if `save_progress()` is called more than once, multiple checkpoints will be created, but the latest one will always be in `latest.vw`\n",
|
|
||||||
"\n",
|
|
||||||
"Next time the chain is loaded, the chain will look for a file called `latest.vw` and if the file exists it will be loaded into the chain and the learning will continue from there.\n",
|
|
||||||
"\n",
|
|
||||||
"By default the rl chain model checkpoints will be stored in the current directory but you can specify the save/load location at chain creation time:\n",
|
|
||||||
"\n",
|
|
||||||
"`chain = rl_chain.PickBest.from_llm(model_save_dir=<path to dir>, [...])`\n",
|
|
||||||
"\n",
|
|
||||||
"### stop learning of learned policy\n",
|
|
||||||
"\n",
|
|
||||||
"If you want the rl chain's learned policy to stop updating you can turn it off/on:\n",
|
|
||||||
"\n",
|
|
||||||
"`chain.deactivate_selection_scorer()` and `chain.activate_selection_scorer()`\n",
|
|
||||||
"\n",
|
|
||||||
"### set a different policy\n",
|
|
||||||
"\n",
|
|
||||||
"There are two policies currently available:\n",
|
|
||||||
"\n",
|
|
||||||
"- default policy: `VwPolicy` which learns a [Vowpal Wabbit](https://github.com/VowpalWabbit/vowpal_wabbit) [Contextual Bandit](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-algorithms) model\n",
|
|
||||||
"\n",
|
|
||||||
"- random policy: `RandomPolicy` which doesn't learn anything and just selects a value randomly. this policy can be used to compare other policies with a random baseline one.\n",
|
|
||||||
"\n",
|
|
||||||
"- custom policies: a custom policy could be created and set at chain creation time\n",
|
|
||||||
"\n",
|
|
||||||
"### different exploration algorithms for the default learned policy\n",
|
|
||||||
"\n",
|
|
||||||
"The default `VwPolicy` is initialized with some default arguments. The default exploration algorithm is [SquareCB](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Contextual-Bandit-Exploration-with-SquareCB) but other Contextual Bandit exploration algorithms can be set, and other hyper parameters can be set also:\n",
|
|
||||||
"\n",
|
|
||||||
"`vw_cmd = [\"--cb_explore_adf\", \"--quiet\", \"--squarecb\", \"--interactions=::\"]`\n",
|
|
||||||
"\n",
|
|
||||||
"`chain = rl_chain.PickBest.from_llm(vw_cmd = vw_cmd, [...])`\n",
|
|
||||||
"\n",
|
|
||||||
"### learn policy's data logs\n",
|
|
||||||
"\n",
|
|
||||||
"The `VwPolicy`'s data files can be stored and examined or used to do [off policy evaluation](https://vowpalwabbit.org/docs/vowpal_wabbit/python/latest/tutorials/off_policy_evaluation.html) for hyper parameter tuning.\n",
|
|
||||||
"\n",
|
|
||||||
"The way to do this is to set a log file path to `vw_logs` on chain creation:\n",
|
|
||||||
"\n",
|
|
||||||
"`chain = rl_chain.PickBest.from_llm(vw_logs=<path to log FILE>, [...])`\n",
|
|
||||||
"\n",
|
|
||||||
"### other advanced featurization options\n",
|
|
||||||
"\n",
|
|
||||||
"Explictly numerical features can be provided with a colon separator:\n",
|
|
||||||
"`age = rl_chain.BasedOn(\"age:32\")`\n",
|
|
||||||
"\n",
|
|
||||||
"`ToSelectFrom` can be a bit more complex if the scenario demands it, instead of being a list of strings it can be:\n",
|
|
||||||
"- a list of list of strings:\n",
|
|
||||||
" ```\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom([\n",
|
|
||||||
" [\"meal 1 name\", \"meal 1 description\"],\n",
|
|
||||||
" [\"meal 2 name\", \"meal 2 description\"]\n",
|
|
||||||
" ])\n",
|
|
||||||
" ```\n",
|
|
||||||
"- a list of dictionaries:\n",
|
|
||||||
" ```\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom([\n",
|
|
||||||
" {\"name\":\"meal 1 name\", \"description\" : \"meal 1 description\"},\n",
|
|
||||||
" {\"name\":\"meal 2 name\", \"description\" : \"meal 2 description\"}\n",
|
|
||||||
" ])\n",
|
|
||||||
" ```\n",
|
|
||||||
"- a list of dictionaries containing lists:\n",
|
|
||||||
" ```\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom([\n",
|
|
||||||
" {\"name\":[\"meal 1\", \"complex name\"], \"description\" : \"meal 1 description\"},\n",
|
|
||||||
" {\"name\":[\"meal 2\", \"complex name\"], \"description\" : \"meal 2 description\"}\n",
|
|
||||||
" ])\n",
|
|
||||||
" ```\n",
|
|
||||||
"\n",
|
|
||||||
"`BasedOn` can also take a list of strings:\n",
|
|
||||||
"```\n",
|
|
||||||
"user = rl_chain.BasedOn([\"Tom Joe\", \"age:32\", \"state of california\"])\n",
|
|
||||||
"```\n",
|
|
||||||
"\n",
|
|
||||||
"there is no dictionary provided since multiple variables can be supplied wrapped in `BasedOn`\n",
|
|
||||||
"\n",
|
|
||||||
"Storing the data logs into a file allows the examination of what different inputs do to the data format.\n",
|
|
||||||
"\n",
|
|
||||||
"### More info on Auto or Custom SelectionScorer\n",
|
|
||||||
"\n",
|
|
||||||
"The selection scorer is very important to get right since the policy uses it to learn. It determines what is called the reward in reinforcement learning, and more specifically in our Contextual Bandits setting.\n",
|
|
||||||
"\n",
|
|
||||||
"The general advice is to keep the score between [0, 1], 0 being the worst selection, 1 being the best selection from the available `ToSelectFrom` variables, based on the `BasedOn` variables, but should be adjusted if the need arises.\n",
|
|
||||||
"\n",
|
|
||||||
"In the examples provided above, the AutoSelectionScorer is set mostly to get users started but in real world scenarios it will most likely not be an adequate scorer function.\n",
|
|
||||||
"\n",
|
|
||||||
"The example also provided the option to change part of the scoring prompt template that the AutoSelectionScorer used to determine whether a selection was good or not:\n",
|
|
||||||
"\n",
|
|
||||||
"```\n",
|
|
||||||
"scoring_criteria_template = \"Given {preference} rank how good or bad this selection is {meal}\"\n",
|
|
||||||
"chain = rl_chain.PickBest.from_llm(\n",
|
|
||||||
" llm=llm,\n",
|
|
||||||
" prompt=PROMPT,\n",
|
|
||||||
" selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, scoring_criteria_template_str=scoring_criteria_template),\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"```\n",
|
|
||||||
"\n",
|
|
||||||
"Internally the AutoSelectionScorer adjusted the scoring prompt to make sure that the llm scoring retured a single float.\n",
|
|
||||||
"\n",
|
|
||||||
"However, if needed, a FULL scoring prompt can also be provided:\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from langchain.prompts.prompt import PromptTemplate\n",
|
|
||||||
"import langchain\n",
|
|
||||||
"langchain.debug = True\n",
|
|
||||||
"\n",
|
|
||||||
"REWARD_PROMPT_TEMPLATE = \"\"\"Given {preference} rank how good or bad this selection is {meal}, IMPORANT: you MUST return a single number between -1 and 1, -1 being bad, 1 being good\"\"\"\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"REWARD_PROMPT = PromptTemplate(\n",
|
|
||||||
" input_variables=[\"preference\", \"meal\"],\n",
|
|
||||||
" template=REWARD_PROMPT_TEMPLATE,\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"chain = rl_chain.PickBest.from_llm(\n",
|
|
||||||
" llm=llm,\n",
|
|
||||||
" prompt=PROMPT,\n",
|
|
||||||
" selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, prompt=REWARD_PROMPT),\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"chain.run(\n",
|
|
||||||
" meal = rl_chain.ToSelectFrom(meals),\n",
|
|
||||||
" user = rl_chain.BasedOn(\"Tom\"),\n",
|
|
||||||
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
|
|
||||||
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
|
|
||||||
")\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "venv",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.8.16"
|
|
||||||
},
|
|
||||||
"orig_nbformat": 4
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user