mirror of https://github.com/hwchase17/langchain (synced 2024-11-08 07:10:35 +00:00)

keep only what is needed for first PR

This commit is contained in:
parent 6de1ca4251
commit 56b40beb0e
@ -1,23 +0,0 @@
name: Unit Tests

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - '*'

jobs:
  python-unit-test:
    container:
      image: python:3.8
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v1
      - name: Run Tests
        shell: bash
        run: |
          pip install -r requirements.txt
          pip install pytest
          python -m pytest tests/
@ -1,6 +0,0 @@
**/__pycache__/**
models/*
logs/*
**/*.vw
.venv
@ -1,21 +0,0 @@
MIT License

Copyright (c) 2023 Vowpal Wabbit

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@ -1,25 +0,0 @@
# VW in a langchain chain

Install the dependencies in `requirements.txt`.

[VowpalWabbit](https://github.com/VowpalWabbit/vowpal_wabbit) is used as the underlying learner.

There is an example notebook (rl_chain.ipynb) with basic usage of the chain.

TLDR:

- The chain is initialized and creates a Vowpal Wabbit instance - only Contextual Bandits and Slates are supported for now
- You can change the VW arguments at chain creation time
- There is a default prompt, but it can be changed
- There is a default reward function that is triggered automatically and in turn triggers learning
- This can be turned off, and the score can be specified explicitly (see the sketch after this list)

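A minimal sketch of the explicit-scoring path, following the example notebooks in this PR (`llm` is assumed to be any already-configured LangChain LLM; the `rl_chain` names come from this PR):

```python
import rl_chain
from langchain.prompts.prompt import PromptTemplate

prompt = PromptTemplate(
    input_variables=["adjective", "content", "topic"],
    template="Hi, please create {adjective} {content} about {topic}",
)

# selection_scorer=None turns the automatic LLM-based scoring off;
# VW arguments can also be overridden at creation time with vw_cmd, which must
# include --slates (the defaults are
# ["--slates", "--quiet", "--interactions=::", "--coin", "--squarecb"])
chain = rl_chain.SlatesPersonalizerChain.from_llm(
    llm=llm, prompt=prompt, selection_scorer=None
)

r = chain.run(
    adjective=rl_chain.ToSelectFrom(["funny", "scary"]),
    content=rl_chain.ToSelectFrom(["poem"]),
    topic=rl_chain.ToSelectFrom(["machine learning", "cats"]),
)

# once you know how good the response was, send the score back so VW can learn
chain.update_with_delayed_score(score=1.0, event=r["selection_metadata"])
```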
Flow:

- Developer: creates the chain
- Developer: sets the actions
- Developer: calls the chain with the context and the other prompt inputs
- Chain: calls VW with the context and selects an action
- Chain: the action (and the other variables) is passed to the LLM with the prompt
- Chain: if the default reward is set, the LLM is called to judge the response and produce a reward score based on the context
- Chain: VW learn is triggered with that score

The sketch below shows this loop end to end.
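A condensed version of the full loop from rl_chain.ipynb (again assuming an already-configured `llm`): VW picks one value per `ToSelectFrom` slot, the LLM fills the prompt with the picks, and `AutoSelectionScorer` asks the LLM to grade the response, which VW then learns from:

```python
import rl_chain
from langchain.prompts.prompt import PromptTemplate

template = """
{prefix}
{goal}: {context}.
{suffix}
"""
prompt = PromptTemplate(
    input_variables=["prefix", "goal", "context", "suffix"],
    template=template,
)

chain = rl_chain.SlatesPersonalizerChain.from_llm(
    llm=llm,
    vw_logs="logs/stories.txt",
    model_save_dir="./models",  # where to save the model checkpoints
    prompt=prompt,
    selection_scorer=rl_chain.AutoSelectionScorer(
        llm=llm,
        scoring_criteria_template_str="""Given the task:
        {goal}: {context}
        rank how good or bad this response is:
        {llm_response}.""",
    ),
    metrics_step=1,
)

chain.run(
    prefix=rl_chain.ToSelectFrom(["ALWAYS DO EXACTLY WHAT I ASK YOU!", "Please do your best to help me."]),
    goal=rl_chain.ToSelectFrom(["Write a funny story about"]),
    context=rl_chain.ToSelectFrom(["Friends series"]),
    suffix=rl_chain.ToSelectFrom(["Please try to be as funny as possible.", ""]),
)
```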
@ -1,364 +0,0 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare core llm chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import langchain\n",
    "langchain.debug = False # set to True if you want to see what the LLM is doing\n",
    "\n",
    "from langchain.chat_models import AzureChatOpenAI\n",
    "\n",
    "import dotenv\n",
    "dotenv.load_dotenv()\n",
    "\n",
    "llm = AzureChatOpenAI(\n",
    "    deployment_name=\"gpt-35-turbo\",\n",
    "    temperature=0,\n",
    "    request_timeout=20,\n",
    "    max_retries=1,\n",
    "    client=None,\n",
    ")\n",
    "\n",
    "llm.predict('Are you ready?')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vanilla LLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains.llm import LLMChain\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "llm_chain = LLMChain(\n",
    "    llm = llm,\n",
    "    prompt = PromptTemplate(\n",
    "        input_variables=[\"adjective\", \"content\", \"topic\"],\n",
    "        template=\"Hi, please create {adjective} {content} about {topic}.\",\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain.run(\n",
    "    adjective = \"funny\",\n",
    "    content = \"poem\",\n",
    "    topic = \"machine learning\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variable selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "llm_chain = rl_chain.SlatesPersonalizerChain.from_llm(\n",
    "    llm=llm,\n",
    "    prompt = PromptTemplate(\n",
    "        input_variables=[\"adjective\", \"content\", \"topic\"],\n",
    "        template=\"Hi, please create {adjective} {content} about {topic}\",\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = llm_chain.run(\n",
    "    adjective = rl_chain.ToSelectFrom([\"funny\"]),\n",
    "    content = rl_chain.ToSelectFrom([\"poem\"]),\n",
    "    topic = rl_chain.ToSelectFrom([\"machine learning\"]))\n",
    "\n",
    "print(r[\"response\"])\n",
    "print(r[\"selection_metadata\"].to_select_from)\n",
    "print(r[\"selection_metadata\"].based_on)\n",
    "print(r[\"selection_metadata\"].selected.score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain.update_with_delayed_score(score=1, event=r[\"selection_metadata\"], force_score=True)\n",
    "print(r[\"selection_metadata\"].selected.score)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It is ok to be uncertain about certain variable values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain.run(\n",
    "    adjective = rl_chain.ToSelectFrom([\"funny\", \"scary\"]),\n",
    "    content = rl_chain.ToSelectFrom([\"poem\"]),\n",
    "    topic = rl_chain.ToSelectFrom([\"machine learning\", \"cats\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Full loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "from langchain.prompts import (\n",
    "    ChatPromptTemplate,\n",
    "    HumanMessagePromptTemplate,\n",
    ")\n",
    "\n",
    "template = \"\"\"\n",
    "using style {style}\n",
    "\n",
    "{prefix}\n",
    "{goal}: {context}.\n",
    "{suffix}\n",
    "\"\"\"\n",
    "prompt = PromptTemplate(\n",
    "    input_variables=[\"prefix\", \"goal\", \"context\", \"suffix\", \"style\"],\n",
    "    template=template,\n",
    ")\n",
    "chain = rl_chain.SlatesPersonalizerChain.from_llm(\n",
    "    llm=llm,\n",
    "    vw_logs = 'logs/stories.txt',\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    prompt = prompt,\n",
    "    selection_scorer = rl_chain.AutoSelectionScorer(\n",
    "        llm=llm,\n",
    "        scoring_criteria_template_str = '''Given the task:\n",
    "        {goal}: {context}\n",
    "        rank how good or bad this response is:\n",
    "        {llm_response}.''',\n",
    "    ),\n",
    "    metrics_step=1\n",
    ")\n",
    "\n",
    "chain.run(\n",
    "    prefix = rl_chain.ToSelectFrom([f'ALWAYS DO EXACTLY WHAT I ASK YOU!', 'Please do your best to help me.']),\n",
    "    goal = rl_chain.ToSelectFrom(['Write a funny story about']),\n",
    "    context = rl_chain.ToSelectFrom(['Friends series']),\n",
    "    suffix = rl_chain.ToSelectFrom(['Please try to be as funny as possible.', '']),\n",
    "    style = \"Shakespeare\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "template = \"\"\"\n",
    "{prefix}\n",
    "{goal}: {context}.\n",
    "{suffix}\n",
    "\"\"\"\n",
    "prompt = PromptTemplate(\n",
    "    input_variables=[\"prefix\", \"goal\", \"context\", \"suffix\"],\n",
    "    template=template,\n",
    ")\n",
    "chain = rl_chain.SlatesPersonalizerChain.from_llm(\n",
    "    llm=llm,\n",
    "    vw_logs = 'logs/stories.txt',\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    prompt = prompt,\n",
    "    selection_scorer = rl_chain.AutoSelectionScorer(\n",
    "        llm=llm,\n",
    "        scoring_criteria_template_str = '''Given the task:\n",
    "        {goal}: {context}\n",
    "        rank how good or bad this response is:\n",
    "        {llm_response}.'''\n",
    "    ),\n",
    "    metrics_step=1\n",
    ")\n",
    "chain.run(\n",
    "    prefix = rl_chain.ToSelectFrom(rl_chain.Embed([f'ALWAYS DO EXACTLY WHAT I ASK YOU!', 'Please do your best to help me.'])),\n",
    "    goal = rl_chain.ToSelectFrom([rl_chain.Embed('Write a funny story about')]),\n",
    "    context = rl_chain.ToSelectFrom(['Friends series']),\n",
    "    suffix = rl_chain.ToSelectFrom(['Please try to be as funny as possible.', '']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment with mock llm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from tests.test_utils import MockScorer\n",
    "\n",
    "class MockLLMChain:\n",
    "    outcomes: List[List[float]] = None\n",
    "\n",
    "    def __init__(self, outcomes, prompt):\n",
    "        self.outcomes = outcomes\n",
    "        self.prompt = prompt\n",
    "\n",
    "    def run(self, prefix, suffix, **kwargs):\n",
    "        return str(self.outcomes[int(prefix)][int(suffix)])\n",
    "\n",
    "import rl_chain\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "template = \"\"\"\n",
    "{prefix}\n",
    "{context}\n",
    "{suffix}\n",
    "\"\"\"\n",
    "prompt = PromptTemplate(\n",
    "    input_variables=[\"prefix\", \"context\", \"suffix\"],\n",
    "    template=template,\n",
    ")\n",
    "chain = rl_chain.SlatesPersonalizerChain.from_llm(\n",
    "    llm=llm,\n",
    "    vw_logs = 'logs/mock.txt',\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    prompt = prompt,\n",
    "    selection_scorer = MockScorer(),\n",
    "    metrics_step=1\n",
    ")\n",
    "chain.llm_chain = MockLLMChain([\n",
    "    [0, 0.3],\n",
    "    [0.6, 0.9]], prompt = prompt)\n",
    "chain.run(\n",
    "    prefix = rl_chain.ToSelectFrom(['0', '1']),\n",
    "    context = rl_chain.ToSelectFrom(['bla']),\n",
    "    suffix = rl_chain.ToSelectFrom(['0', '1']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "vw_chain = rl_chain.SlatesPersonalizerChain.from_llm(\n",
    "    llm=llm,\n",
    "    vw_logs = 'logs/mock.txt',\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    prompt = prompt,\n",
    "    policy = rl_chain.VwPolicy,\n",
    "    selection_scorer = MockScorer(),\n",
    "    auto_embed=False,\n",
    "    metrics_step=1\n",
    ")\n",
    "vw_chain.llm_chain = MockLLMChain([\n",
    "    [0, 0.3],\n",
    "    [0.6, 0.9]], prompt = prompt)\n",
    "\n",
    "rnd_chain = rl_chain.SlatesPersonalizerChain.from_llm(\n",
    "    llm=llm,\n",
    "    vw_logs = 'logs/mock.txt',\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    prompt = prompt,\n",
    "    policy = rl_chain.SlatesRandomPolicy,\n",
    "    selection_scorer = MockScorer(),\n",
    "    auto_embed=False,\n",
    "    metrics_step=1\n",
    ")\n",
    "rnd_chain.llm_chain = MockLLMChain([\n",
    "    [0, 0.3],\n",
    "    [0.6, 0.9]], prompt = prompt)\n",
    "\n",
    "for i in range(1000):\n",
    "    vw_chain.run(\n",
    "        prefix = rl_chain.ToSelectFrom(['0', '1']),\n",
    "        context = rl_chain.ToSelectFrom(['bla']),\n",
    "        suffix = rl_chain.ToSelectFrom(['0']))\n",
    "    rnd_chain.run(\n",
    "        prefix = rl_chain.ToSelectFrom(['0', '1']),\n",
    "        context = rl_chain.ToSelectFrom(['bla']),\n",
    "        suffix = rl_chain.ToSelectFrom(['0']))\n",
    "\n",
    "vw_chain.metrics.to_pandas()['score'].plot(label=\"vw\")\n",
"rnd_chain.metrics.to_pandas()['score'].plot(label=\"slates\")\n",
    "plt.legend()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@ -1,403 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MealPlanner:\n",
    "    def __init__(self, name: str, desc: str, difficulty: str, tags: str):\n",
    "        try:\n",
    "            self.name = name\n",
    "            self.desc = desc\n",
    "            self.diff = difficulty\n",
    "            self.tags = tags\n",
    "        except:\n",
    "            print(name)\n",
    "            raise ValueError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Actions\n",
    "## examples copied from hello fresh website\n",
    "actions = [\n",
    "    MealPlanner(name=\"One-Pan Beef Enchiladas Verdes with Mexican Cheese Blend & Hot Sauce Crema\", difficulty=\"Easy\", tags=\"Spicy, Easy Cleanup, Easy Prep\", desc=\"When it comes to Mexican-style cuisine, burritos typically get all the glory. In our humble opinion, enchiladas are an unsung dinner hero. They’re technically easier-to-assemble burritos that get smothered in a delicious sauce, but they’re really so much more than that! Ours start with spiced beef and charred green pepper that get rolled up in warm tortillas. This winning combo gets topped with tangy salsa verde and cheese, then baked until bubbly and melty. Hear that? That’s the sound of the dinner bell!\"),\n",
    "    MealPlanner(name=\"Chicken & Mushroom Flatbreads with Gouda Cream Sauce & Parmesan\", difficulty=\"Easy\", tags=\"\", desc=\"Yes we love our simple cheese pizza with red sauce but tonight, move over, marinara—there’s a new sauce in town. In this recipe, crispy flatbreads are slathered with a rich, creamy gouda-mustard sauce we just can’t get enough of. We top that off with a pile of caramelized onion and earthy cremini mushrooms. Shower with Parmesan, and that’s it. Simple, satisfying, and all in 30 minutes–a dinner idea you can’t pass up!\"),\n",
    "    MealPlanner(name=\"Sweet Potato & Pepper Quesadillas with Southwest Crema & Tomato Salsa\", difficulty=\"Easy\", tags=\"Veggie\", desc=\"This quesadilla is jam-packed with flavorful roasted sweet potato and green pepper, plus two types of gooey, melty cheese (how could we choose just one?!). Of course, we’d never forget the toppings—there’s a fresh tomato salsa and dollops of spiced lime crema. Now for the fun part: piling on a little bit of everything to construct the perfect bite!\"),\n",
    "    MealPlanner(name=\"One-Pan Trattoria Tortelloni Bake with a Crispy Parmesan Panko Topping\", difficulty=\"Easy\", tags=\"Veggie, Easy Cleanup, Easy Prep\", desc=\"Think a cheesy stuffed pasta can’t get any better? What about baking it in a creamy sauce with a crispy topping? In this recipe, we toss cheese-stuffed tortelloni in an herby tomato cream sauce, then top with Parmesan and panko breadcrumbs. Once broiled, it turns into a showstopping topping that’ll earn you plenty of oohs and aahs from your lucky fellow diners.\"),\n",
    "]\n",
    "\n",
    "meals = [f'title={action.name.replace(\":\", \"\").replace(\"|\", \"\")}' for action in actions]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import AzureChatOpenAI\n",
    "import langchain\n",
    "langchain.debug = False\n",
    "# assuming LLM api keys have been set in the environment\n",
    "# you can use whatever LLM you want here, it doesn't have to be AzureChatOpenAI\n",
    "\n",
    "llm = AzureChatOpenAI(\n",
    "    deployment_name=\"gpt-35-turbo\",\n",
    "    temperature=0,\n",
    "    request_timeout=10,\n",
    "    max_retries=1,\n",
    "    client=None,\n",
    ")\n",
    "\n",
    "llm.predict('Are you ready?')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### default chain, default reward (the LLM is used to judge and rank the response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "from langchain.prompts import PromptTemplate\n",
    "\n",
    "import logging\n",
    "logger = logging.getLogger(\"rl_chain\")\n",
    "logger.setLevel(logging.INFO)\n",
    "\n",
    "_PROMPT_TEMPLATE = \"\"\"Here is the description of a meal: {meal}.\n",
    "\n",
    "You have to embed this into the given text where it makes sense. Here is the given text: {text_to_personalize}.\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "PROMPT = PromptTemplate(\n",
    "    input_variables=[\"meal\", \"text_to_personalize\"], template=_PROMPT_TEMPLATE\n",
    ")\n",
    "\n",
    "chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = chain.run(\n",
    "    meal = rl_chain.ToSelectFrom(meals),\n",
    "    User = rl_chain.BasedOn(\"Tom Hanks\"),\n",
    "    preference = rl_chain.BasedOn(\"Vegetarian, regular dairy is ok\"),\n",
    "    text_to_personalize = \"This is the week's specialty dish, our master chefs believe you will love it!\",\n",
    ")\n",
    "\n",
    "print(response[\"response\"])\n",
    "rr = response[\"selection_metadata\"]\n",
    "print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "_OTHER_PROMPT_TEMPLATE = \"\"\"You can use the actions that were chosen by VW like so: {action}.\n",
    "\n",
    "And use whatever other vars you want to pass into the chain at run: {some_text}. And {some_other_text}\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "OTHER_PROMPT = PromptTemplate(\n",
    "    input_variables=[\"action\", \"some_text\", \"some_other_text\"],\n",
    "    template=_OTHER_PROMPT_TEMPLATE,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain.pick_best_chain\n",
    "\n",
    "chain = rl_chain.PickBest.from_llm(\n",
    "    llm=llm,\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    prompt=OTHER_PROMPT,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = chain.run(\n",
    "    some_text = \"This is some text\",\n",
    "    some_other_text = \"This is some other text\",\n",
    "    action=rl_chain.ToSelectFrom([\"an action\", \"another action\", \"a third action\"]),\n",
    "    User = rl_chain.BasedOn(\"Tom\"),\n",
    "    preference = rl_chain.BasedOn(\"Vegetarian\")\n",
    ")\n",
    "\n",
    "print(response[\"response\"])\n",
    "rr = response[\"selection_metadata\"]\n",
    "print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### actions and context with multiple namespaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# each action is a dictionary of namespace to action string\n",
    "# this example shows that while embedding is recommended for all features, it is not required and can be customized\n",
    "action_strs_w_ns = [{\"A\":\"an action feature\", \"B\" : rl_chain.Embed(\"another action feature\")}, {\"B\": \"another action\"}, {\"C\":\"a third action\"}]\n",
    "\n",
    "inputs = {\n",
    "    \"some_text\": \"This is some text\",\n",
    "    \"some_other_text\": \"This is some other text\",\n",
    "    \"action\" : rl_chain.ToSelectFrom(action_strs_w_ns)\n",
    "}\n",
    "\n",
    "inputs[\"User\"] = rl_chain.BasedOn(\"Tom\")\n",
    "inputs[\"preference\"] = rl_chain.BasedOn(rl_chain.Embed(\"Vegetarian\"))\n",
    "response = chain.run(inputs)\n",
    "print(response[\"response\"])\n",
    "rr = response[\"selection_metadata\"]\n",
    "print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### chain with default prompt and custom reward prompt (the LLM is used to judge and rank the response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import OpenAI\n",
    "\n",
    "llm = OpenAI(engine=\"text-davinci-003\")\n",
    "\n",
    "llm('Are you ready?')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "\n",
    "human_template = \"Given {preference} rank how good or bad this selection is {action}\"\n",
    "\n",
    "chain = rl_chain.PickBest.from_llm(\n",
    "    llm=llm,\n",
    "    prompt=OTHER_PROMPT,\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, scoring_criteria_template_str=human_template),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "actions = [\"an action\", \"another action\", \"a third action\"]\n",
    "\n",
    "response = chain.run(\n",
    "    some_text = \"Some text\",\n",
    "    some_other_text = \"Some other text\",\n",
    "    action=rl_chain.ToSelectFrom(actions),\n",
    "    User = rl_chain.BasedOn(\"Tom\"),\n",
    "    preference = rl_chain.BasedOn(\"Vegetarian\"),\n",
    ")\n",
    "print(response[\"response\"])\n",
    "rr = response[\"selection_metadata\"]\n",
    "print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "_REWARD_PROMPT_TEMPLATE = \"\"\"Given {preference} rank how good or bad this selection is {action}, IMPORTANT: you MUST return a single number between 0 and 1, 0 being bad, 1 being good\"\"\"\n",
    "\n",
    "\n",
    "REWARD_PROMPT = PromptTemplate(\n",
    "    input_variables=[\"preference\", \"action\"],\n",
    "    template=_REWARD_PROMPT_TEMPLATE,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "\n",
    "human_template = \"Given {preference} rank how good or bad this selection is {action}\"\n",
    "\n",
    "chain = rl_chain.PickBest.from_llm(\n",
    "    llm=llm,\n",
    "    prompt=OTHER_PROMPT,\n",
    "    model_save_dir=\"./models\", # where to save the model checkpoints\n",
    "    selection_scorer=rl_chain.AutoSelectionScorer(llm=llm, prompt=REWARD_PROMPT),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "actions = [\"an action\", \"another action\", \"a third action\"]\n",
    "\n",
    "response = chain.run(\n",
    "    some_text = \"Some text\",\n",
    "    some_other_text = \"Some other text\",\n",
    "    action=rl_chain.ToSelectFrom(actions),\n",
    "    User = rl_chain.BasedOn(\"Tom\"),\n",
    "    preference = rl_chain.BasedOn(\"Vegetarian\"),\n",
    ")\n",
    "print(response[\"response\"])\n",
    "rr = response[\"selection_metadata\"]\n",
    "print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### other reward options\n",
    "\n",
    "custom reward class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a custom reward class/function is just another class that inherits from rl_chain.SelectionScorer and implements the score_response method\n",
    "import rl_chain\n",
    "\n",
    "class CustomSelectionScorer(rl_chain.SelectionScorer):\n",
    "    # grade or score the response\n",
    "    def score_response(\n",
    "        self, inputs, llm_response: str\n",
    "    ) -> float:\n",
    "        # do whatever you want here, use whatever inputs you supplied and return a reward\n",
    "        reward = 1.0\n",
    "        return reward\n",
    "\n",
    "# set this in the chain during construction (selection_scorer=CustomSelectionScorer()) and it will be called automatically"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Asynchronous user-defined reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import rl_chain\n",
    "\n",
    "chain = rl_chain.PickBest.from_llm(\n",
    "    llm=llm,\n",
    "    prompt=PROMPT,\n",
    "    selection_scorer=None)\n",
    "\n",
    "# whenever you have the reward for the call, send it back to the chain to learn from\n",
    "\n",
    "response = chain.run(text_to_personalize = \"This is the week's specialty dish, our master chefs believe you will love it!\",\n",
    "                     meal = rl_chain.ToSelectFrom(meals),\n",
    "                     User = rl_chain.BasedOn(rl_chain.Embed(\"Tom\")),\n",
    "                     preference = rl_chain.BasedOn(\"Vegetarian\")\n",
    "                     )\n",
    "print(response[\"response\"])\n",
    "rr = response[\"selection_metadata\"]\n",
    "# score should be None here because we turned automatic scoring off\n",
    "print(f\"score: {rr.selected.score}, action: {rr.selected.index}, probability: {rr.selected.probability}, \")\n",
    "\n",
    "# learn from the delayed score/grade\n",
    "chain.update_with_delayed_score(score=1.0, event=rr)\n",
    "\n",
    "print(f\"score: {rr.selected.score}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@ -1,275 +0,0 @@
from __future__ import annotations

from . import rl_chain_base as base
from langchain.prompts.prompt import PromptTemplate

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain

from typing import Any, Dict, List, Optional, Tuple, Union
from itertools import chain
import random

from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from sentence_transformers import SentenceTransformer

# sentinel object used to distinguish between the user not supplying anything
# and the user explicitly supplying None
SENTINEL = object()


class SlatesFeatureEmbedder(base.Embedder):
    """
    Slates feature embedder that embeds the context and actions into a text format that VW slates can consume

    Attributes:
        model (Any, optional): The type of embeddings to be used for feature representation. Defaults to the BERT Sentence Transformer
    """

    def __init__(self, model: Optional[Any] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if model is None:
            model = SentenceTransformer("bert-base-nli-mean-tokens")

        self.model = model

    def to_action_features(self, actions: Dict[str, Any]):
        def _str(embedding):
            return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])

        action_features = []
        for slot in actions.values():
            slot_features = []
            for action in slot:
                if isinstance(action, base._Embed) and action.keep:
                    # keep the raw (underscored) text alongside its embedding
                    feature = (
                        action.value.replace(" ", "_")
                        + " "
                        + _str(self.model.encode(action.value))
                    )
                elif isinstance(action, base._Embed):
                    feature = _str(self.model.encode(action.value))
                else:
                    feature = action.replace(" ", "_")
                slot_features.append(feature)
            action_features.append(slot_features)

        return action_features

    def format(self, event: SlatesPersonalizerChain.Event) -> str:
        action_features = self.to_action_features(event.to_select_from)

        # VW expects a cost (the negated reward) on the shared line once a score is known
        cost = (
            -1.0 * event.selected.score
            if event.selected and event.selected.score is not None
            else ""
        )
        context_str = f"slates shared {cost} "

        if event.based_on:
            embedded_context = base.embed(event.based_on, self.model)
            for context_item in embedded_context:
                for ns, ctx in context_item.items():
                    context_str += (
                        f"|{ns} {' '.join(ctx) if isinstance(ctx, list) else ctx} "
                    )
        else:
            context_str += "|"  # empty context

        # one line per candidate action, tagged with the slot it belongs to
        actions = chain.from_iterable(
            [
                [f"slates action {i} |Action {action}"]
                for i, slot in enumerate(action_features)
                for action in slot
            ]
        )
        # one line per slot, carrying the chosen index and probability when available
        ps = (
            [f"{a}:{p}" for a, p in event.selected.get_indexes_and_probabilities()]
            if event.selected
            else [""] * len(action_features)
        )
        slots = [f"slates slot {p} |" for p in ps]
        return "\n".join(list(chain.from_iterable([[context_str], actions, slots])))


class SlatesRandomPolicy(base.Policy):
    def __init__(self, feature_embedder: base.Embedder, *_, **__):
        self.feature_embedder = feature_embedder

    def predict(self, event: SlatesPersonalizerChain.Event) -> Any:
        return [
            [(random.randint(0, len(slot) - 1), 1.0 / len(slot))]
            for _, slot in event.to_select_from.items()
        ]

    def learn(self, event: SlatesPersonalizerChain.Event) -> Any:
        pass

    def log(self, event: SlatesPersonalizerChain.Event) -> Any:
        pass


class SlatesFirstChoicePolicy(base.Policy):
    def __init__(self, feature_embedder: base.Embedder, *_, **__):
        self.feature_embedder = feature_embedder

    def predict(self, event: SlatesPersonalizerChain.Event) -> Any:
        return [[(0, 1)] for _ in event.to_select_from]

    def learn(self, event: SlatesPersonalizerChain.Event) -> Any:
        pass

    def log(self, event: SlatesPersonalizerChain.Event) -> Any:
        pass


class SlatesPersonalizerChain(base.RLChain):
    class Selected(base.Selected):
        indexes: Optional[List[int]]
        probabilities: Optional[List[float]]
        score: Optional[float]

        def __init__(
            self,
            indexes: Optional[List[int]] = None,
            probabilities: Optional[List[float]] = None,
            score: Optional[float] = None,
        ):
            self.indexes = indexes
            self.probabilities = probabilities
            self.score = score

        def get_indexes_and_probabilities(self):
            return zip(self.indexes, self.probabilities)

    class Event(base.Event):
        def __init__(
            self,
            inputs: Dict[str, Any],
            to_select_from: Dict[str, Any],
            based_on: Dict[str, Any],
            selected: Optional[SlatesPersonalizerChain.Selected] = None,
        ):
            super().__init__(inputs=inputs, selected=selected)
            self.to_select_from = to_select_from
            self.based_on = based_on

    def __init__(
        self, feature_embedder: Optional[base.Embedder] = None, *args, **kwargs
    ):
        vw_cmd = kwargs.get("vw_cmd", [])
        if not vw_cmd:
            vw_cmd = [
                "--slates",
                "--quiet",
                "--interactions=::",
                "--coin",
                "--squarecb",
            ]
        else:
            if "--slates" not in vw_cmd:
                raise ValueError("If vw_cmd is specified, it must include --slates")

        kwargs["vw_cmd"] = vw_cmd

        if feature_embedder is None:
            feature_embedder = SlatesFeatureEmbedder()

        super().__init__(feature_embedder=feature_embedder, *args, **kwargs)

    def _call_before_predict(
        self, inputs: Dict[str, Any]
    ) -> SlatesPersonalizerChain.Event:
        context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
        event = SlatesPersonalizerChain.Event(
            inputs=inputs, to_select_from=actions, based_on=context
        )
        return event

    def _call_after_predict_before_llm(
        self,
        inputs: Dict[str, Any],
        event: SlatesPersonalizerChain.Event,
        prediction: List[List[Tuple[int, float]]],
    ) -> Tuple[Dict[str, Any], SlatesPersonalizerChain.Event]:
        indexes = [p[0][0] for p in prediction]
        probabilities = [p[0][1] for p in prediction]
        selected = SlatesPersonalizerChain.Selected(
            indexes=indexes, probabilities=probabilities
        )
        event.selected = selected

        preds = {}
        for i, (j, a) in enumerate(
            zip(event.selected.indexes, event.to_select_from.values())
        ):
            preds[list(event.to_select_from.keys())[i]] = str(a[j])

        next_chain_inputs = inputs.copy()
        next_chain_inputs.update(preds)

        return next_chain_inputs, event

    def _call_after_llm_before_scoring(
        self, llm_response: str, event: SlatesPersonalizerChain.Event
    ) -> Tuple[Dict[str, Any], SlatesPersonalizerChain.Event]:
        next_chain_inputs = event.inputs.copy()
        next_chain_inputs.update(
            {
                self.selected_based_on_input_key: str(event.based_on),
                self.selected_input_key: str(event.to_select_from),
            }
        )
        return next_chain_inputs, event

    def _call_after_scoring_before_learning(
        self, event: Event, score: Optional[float]
    ) -> SlatesPersonalizerChain.Event:
        event.selected.score = score
        return event

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        return super()._call(run_manager=run_manager, inputs=inputs)

    @property
    def _chain_type(self) -> str:
        return "llm_personalizer_chain"

    @classmethod
    def from_chain(
        cls,
        llm_chain: Chain,
        prompt: PromptTemplate,
        selection_scorer=SENTINEL,
        **kwargs: Any,
    ):
        if selection_scorer is SENTINEL:
            selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
        return SlatesPersonalizerChain(
            llm_chain=llm_chain,
            prompt=prompt,
            selection_scorer=selection_scorer,
            **kwargs,
        )

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: PromptTemplate,
        selection_scorer=SENTINEL,
        **kwargs: Any,
    ):
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return SlatesPersonalizerChain.from_chain(
            llm_chain=llm_chain,
            prompt=prompt,
            selection_scorer=selection_scorer,
            **kwargs,
        )