Merge pull request #1 from VowpalWabbit/add_rl_chain

Initial commit of rl_chain code
This commit is contained in:
olgavrou 2023-08-22 09:18:23 -04:00 committed by GitHub
commit e9423300d9
11 changed files with 1988 additions and 0 deletions

View File

@@ -0,0 +1,28 @@
from langchain.chains.rl_chain.pick_best_chain import PickBest
from langchain.chains.rl_chain.base import (
Embed,
BasedOn,
ToSelectFrom,
SelectionScorer,
AutoSelectionScorer,
Embedder,
Policy,
VwPolicy,
)
import logging
def configure_logger():
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)
logger.addHandler(ch)
configure_logger()
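# Typical import surface for this package (a sketch; only names exported above):
#   from langchain.chains.rl_chain import PickBest, BasedOn, ToSelectFrom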

View File

@@ -0,0 +1,551 @@
from __future__ import annotations
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence
from abc import ABC, abstractmethod
import vowpal_wabbit_next as vw
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.metrics import MetricsTracker
from langchain.prompts import BasePromptTemplate
from langchain.pydantic_v1 import Extra, BaseModel, root_validator
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
logger = logging.getLogger(__name__)
class _BasedOn:
def __init__(self, value):
self.value = value
def __str__(self):
return str(self.value)
__repr__ = __str__
def BasedOn(anything):
return _BasedOn(anything)
class _ToSelectFrom:
def __init__(self, value):
self.value = value
def __str__(self):
return str(self.value)
__repr__ = __str__
def ToSelectFrom(anything):
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
class _Embed:
def __init__(self, value, keep=False):
self.value = value
self.keep = keep
def __str__(self):
return str(self.value)
__repr__ = __str__
def Embed(anything, keep=False):
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
return BasedOn(Embed(anything.value, keep=keep))
if isinstance(anything, list):
return [Embed(v, keep=keep) for v in anything]
elif isinstance(anything, dict):
return {k: Embed(v, keep=keep) for k, v in anything.items()}
elif isinstance(anything, _Embed):
return anything
return _Embed(anything, keep=keep)
def EmbedAndKeep(anything):
return Embed(anything, keep=True)
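# Wrapper semantics at a glance (a sketch; behavior is exercised by the unit
# tests later in this PR):
#   Embed("x")        -> the raw feature is replaced by its embedding
#   EmbedAndKeep("x") -> the raw feature is kept alongside its embedding
#   Embed([...]) and Embed({...}) recurse into list items and dict values,
#   and any BasedOn/ToSelectFrom wrapping is preserved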
# helper functions
def parse_lines(parser: vw.TextFormatParser, input_str: str) -> List[vw.Example]:
return [parser.parse_line(line) for line in input_str.split("\n")]
def get_based_on_and_to_select_from(inputs: Dict[str, Any]):
to_select_from = {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
}
if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)
based_on = {
k: inputs[k].value if isinstance(inputs[k].value, list) else [inputs[k].value]
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
}
return based_on, to_select_from
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]):
# go over all the inputs and, if something is wrapped in _ToSelectFrom or _BasedOn
# and its inner value is not already _Embed, wrap it in EmbedAndKeep while
# retaining its _ToSelectFrom or _BasedOn status
next_inputs = inputs.copy()
for k, v in next_inputs.items():
if isinstance(v, _ToSelectFrom) or isinstance(v, _BasedOn):
if not isinstance(v.value, _Embed):
next_inputs[k].value = EmbedAndKeep(v.value)
return next_inputs
# end helper functions
class Selected(ABC):
pass
class Event(ABC):
inputs: Dict[str, Any]
selected: Optional[Selected]
def __init__(self, inputs: Dict[str, Any], selected: Optional[Selected] = None):
self.inputs = inputs
self.selected = selected
class Policy(ABC):
@abstractmethod
def predict(self, event: Event) -> Any:
pass
@abstractmethod
def learn(self, event: Event):
pass
@abstractmethod
def log(self, event: Event):
pass
def save(self):
pass
class VwPolicy(Policy):
def __init__(
self,
model_repo: ModelRepository,
vw_cmd: Sequence[str],
feature_embedder: Embedder,
vw_logger: VwLogger,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.model_repo = model_repo
self.workspace = self.model_repo.load(vw_cmd)
self.feature_embedder = feature_embedder
self.vw_logger = vw_logger
def predict(self, event: Event) -> Any:
text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one(
parse_lines(text_parser, self.feature_embedder.format(event))
)
def learn(self, event: Event):
vw_ex = self.feature_embedder.format(event)
text_parser = vw.TextFormatParser(self.workspace)
multi_ex = parse_lines(text_parser, vw_ex)
self.workspace.learn_one(multi_ex)
def log(self, event: Event):
if self.vw_logger.logging_enabled():
vw_ex = self.feature_embedder.format(event)
self.vw_logger.log(vw_ex)
def save(self):
self.model_repo.save(self.workspace)
class Embedder(ABC):
@abstractmethod
def format(self, event: Event) -> str:
pass
class SelectionScorer(ABC, BaseModel):
"""Abstract method to grade the chosen selection or the response of the llm"""
@abstractmethod
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
pass
class AutoSelectionScorer(SelectionScorer, BaseModel):
llm_chain: Union[LLMChain, None] = None
prompt: Union[BasePromptTemplate, None] = None
scoring_criteria_template_str: Optional[str] = None
@staticmethod
def get_default_system_prompt() -> SystemMessagePromptTemplate:
return SystemMessagePromptTemplate.from_template(
"PLEASE RESPOND ONLY WITH A SIGNLE FLOAT AND NO OTHER TEXT EXPLANATION\n You are a strict judge that is called on to rank a response based on given criteria.\
You must respond with your ranking by providing a single float within the range [0, 1], 0 being very bad response and 1 being very good response."
)
@staticmethod
def get_default_prompt() -> ChatPromptTemplate:
human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{llm_response}".'
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
chat_prompt = ChatPromptTemplate.from_messages(
[default_system_prompt, human_message_prompt]
)
return chat_prompt
@root_validator(pre=True)
def set_prompt_and_llm_chain(cls, values):
llm = values.get("llm")
prompt = values.get("prompt")
scoring_criteria_template_str = values.get("scoring_criteria_template_str")
if prompt is None and scoring_criteria_template_str is None:
prompt = AutoSelectionScorer.get_default_prompt()
elif prompt is None and scoring_criteria_template_str is not None:
human_message_prompt = HumanMessagePromptTemplate.from_template(
scoring_criteria_template_str
)
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
prompt = ChatPromptTemplate.from_messages(
[default_system_prompt, human_message_prompt]
)
values["prompt"] = prompt
values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
return values
def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
ranking = ranking.strip()
try:
resp = float(ranking)
return resp
except Exception as e:
raise RuntimeError(
f"The llm did not manage to rank the response as expected, there is always the option to try again or tweak the reward prompt. Error: {e}"
)
class RLChain(Chain):
"""
RLChain class that utilizes the Vowpal Wabbit (VW) model for personalization.
Attributes:
feature_embedder (Embedder): The embedder used to convert the BasedOn and ToSelectFrom inputs into the text format the VW model consumes.
model_save_dir (str, optional): The directory to save the VW model to. Defaults to the current directory.
reset_model (bool, optional): If set to True, the chain will start training from scratch, potentially overwriting existing model files. If set to False, it will attempt to load an existing VW model from the latest checkpoint file (`latest.vw`) in the model_save_dir directory. Defaults to False.
vw_cmd (List[str], optional): Advanced users can set the VW command line to whatever they want, as long as it is compatible with the policy being used.
selection_scorer (SelectionScorer): If set, the chain will check the response using the provided selection_scorer and the VW model will be updated with the result. Defaults to None.
Notes:
The class creates a VW model instance using the provided arguments. Before the chain object is destroyed the save_progress() function can be called. If it is called, the learned VW model is saved to `latest.vw` in the model save directory, and a timestamped checkpoint named `model-<YYYYMMDD-HHMMSS>.vw` is kept alongside it.
When making predictions, VW is first called to choose action(s), which are then substituted into the prompt under the variable name that was wrapped in ToSelectFrom. After action selection, the LLM (Language Model) is called with the populated prompt, and the response is returned.
"""
llm_chain: Chain
output_key: str = "result" #: :meta private:
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
policy: Optional[Policy]
auto_embed: bool = True
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
metrics: Optional[MetricsTracker] = None
def __init__(
self,
feature_embedder: Embedder,
model_save_dir="./",
reset_model=False,
vw_cmd=None,
policy=VwPolicy,
vw_logs: Optional[Union[str, os.PathLike]] = None,
metrics_step=-1,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
if self.selection_scorer is None:
logger.warning(
"No response validator provided, which means that no reinforcement learning will be done in the RL chain unless update_with_delayed_score is called."
)
self.policy = policy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd or [],
feature_embedder=feature_embedder,
vw_logger=VwLogger(vw_logs),
)
self.metrics = MetricsTracker(step=metrics_step)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return []
@property
def output_keys(self) -> List[str]:
"""Expect output key.
:meta private:
"""
return [self.output_key]
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
super()._validate_inputs(inputs)
if (
self.selected_input_key in inputs.keys()
or self.selected_based_on_input_key in inputs.keys()
):
raise ValueError(
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
)
@abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
pass
@abstractmethod
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: Any
) -> Tuple[Dict[str, Any], Event]:
pass
@abstractmethod
def _call_after_llm_before_scoring(
self, llm_response: str, event: Event
) -> Tuple[Dict[str, Any], Event]:
pass
@abstractmethod
def _call_after_scoring_before_learning(
self, event: Event, score: Optional[float]
) -> Event:
pass
def update_with_delayed_score(
self, score: float, event: Event, force_score=False
) -> None:
"""
Learn will be called with the specified score and the actions/embeddings/etc. stored in the event.
Will raise an error if selection_scorer is set and force_score=True was not provided during the method call.
"""
if self.selection_scorer and not force_score:
raise RuntimeError(
"The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
)
self.metrics.on_feedback(score)
self._call_after_scoring_before_learning(event=event, score=score)
self.policy.learn(event=event)
self.policy.log(event=event)
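# Delayed-scoring sketch (names mirror the unit tests later in this PR):
# run the chain with selection_scorer=None, obtain a score out of band,
# then feed it back through update_with_delayed_score:
#   response = chain.run(
#       User=BasedOn("Context"),
#       action=ToSelectFrom(["0", "1", "2"]),
#   )
#   chain.update_with_delayed_score(
#       score=1.0, event=response["selection_metadata"]
#   )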
def set_auto_embed(self, auto_embed: bool) -> None:
"""
Set whether the chain should auto embed the inputs or not. If set to False, the inputs will not be embedded and the user will need to embed the inputs themselves before calling run.
Args:
auto_embed (bool): Whether the chain should auto embed the inputs or not.
"""
self.auto_embed = auto_embed
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
if self.auto_embed:
inputs = prepare_inputs_for_autoembed(inputs=inputs)
event = self._call_before_predict(inputs=inputs)
prediction = self.policy.predict(event=event)
self.metrics.on_decision()
next_chain_inputs, event = self._call_after_predict_before_llm(
inputs=inputs, event=event, prediction=prediction
)
t = self.llm_chain.run(**next_chain_inputs, callbacks=_run_manager.get_child())
_run_manager.on_text(t, color="green", verbose=self.verbose)
t = t.strip()
if self.verbose:
_run_manager.on_text("\nCode: ", verbose=self.verbose)
output = t
_run_manager.on_text("\nAnswer: ", verbose=self.verbose)
_run_manager.on_text(output, color="yellow", verbose=self.verbose)
next_chain_inputs, event = self._call_after_llm_before_scoring(
llm_response=output, event=event
)
score = None
try:
if self.selection_scorer:
score = self.selection_scorer.score_response(
inputs=next_chain_inputs, llm_response=output
)
except Exception as e:
logger.info(
f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
)
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event)
self.policy.learn(event=event)
self.policy.log(event=event)
return {self.output_key: {"response": output, "selection_metadata": event}}
def save_progress(self) -> None:
"""
This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.
File Naming Convention:
The current model state is always written to `latest.vw` in the model save directory. When history is enabled, a timestamped checkpoint is also written using the pattern `model-<YYYYMMDD-HHMMSS>.vw`.
Example:
If the model is saved on 2023-08-22 at 09:18:23, the checkpoint will be named `model-20230822-091823.vw`, and `latest.vw` will be updated to the same state.
Note:
Be cautious when deleting or renaming checkpoint files manually, as the chain relies on `latest.vw` to resume from the most recent state.
"""
self.policy.save()
@property
def _chain_type(self) -> str:
return "llm_personalizer_chain"
def is_stringtype_instance(item: Any) -> bool:
"""Helper function to check if an item is a string."""
return isinstance(item, str) or (
isinstance(item, _Embed) and isinstance(item.value, str)
)
def embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, str]:
"""Helper function to embed a string or an _Embed object."""
join_char = ""
keep_str = ""
if isinstance(item, _Embed):
encoded = model.encode(item.value)
join_char = " "
if item.keep:
keep_str = item.value.replace(" ", "_") + " "
elif isinstance(item, str):
encoded = item.replace(" ", "_")
join_char = ""
else:
raise ValueError(f"Unsupported type {type(item)} for embedding")
if namespace is None:
raise ValueError(
"The default namespace must be provided when embedding a string or _Embed object."
)
return {namespace: keep_str + join_char.join(map(str, encoded))}
def embed_dict_type(item: Dict, model: Any) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a dictionary item."""
inner_dict = {}
for ns, embed_item in item.items():
if isinstance(embed_item, list):
inner_dict[ns] = []
for embed_list_item in embed_item:
embedded = embed_string_type(embed_list_item, model, ns)
inner_dict[ns].append(embedded[ns])
else:
inner_dict.update(embed_string_type(embed_item, model, ns))
return inner_dict
def embed_list_type(
item: list, model: Any, namespace: Optional[str] = None
) -> List[Dict[str, Union[str, List[str]]]]:
ret_list = []
for embed_item in item:
if isinstance(embed_item, dict):
ret_list.append(embed_dict_type(embed_item, model))
else:
ret_list.append(embed_string_type(embed_item, model, namespace))
return ret_list
def embed(
to_embed: Union[
Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]
],
model: Any,
namespace: Optional[str] = None,
) -> List[Dict[str, Union[str, List[str]]]]:
"""
Embeds the actions or context using the SentenceTransformer model
Attributes:
to_embed: (Union[str, _Embed, Dict, List[Union[str, _Embed]], List[Dict]], required) The text to be embedded: either a string, a dictionary, a list of strings, or a list of dictionaries.
namespace: (str, optional) The default namespace to use when the input is not a dictionary or a list of dictionaries.
model: (Any, required) The model to use for embedding
Returns:
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
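Example:
embed("test", model, "a_namespace") returns [{"a_namespace": "test"}]
embed(Embed("test"), model, "a_namespace") returns the encoded form, e.g. [{"a_namespace": "<embedding of 'test'>"}]
(Illustrative sketch; the exact encoded output depends on the model. See the unit tests in this PR.)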
"""
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
to_embed, str
):
return [embed_string_type(to_embed, model, namespace)]
elif isinstance(to_embed, dict):
return [embed_dict_type(to_embed, model)]
elif isinstance(to_embed, list):
return embed_list_type(to_embed, model, namespace)
else:
raise ValueError("Invalid input format for embedding")

View File

@@ -0,0 +1,27 @@
import pandas as pd
from typing import Optional
class MetricsTracker:
def __init__(self, step: int):
self._history = []
self._step = step
self._i = 0
self._num = 0
self._denom = 0
@property
def score(self) -> float:
return self._num / self._denom if self._denom > 0 else 0
def on_decision(self) -> None:
self._denom += 1
def on_feedback(self, score: Optional[float]) -> None:
self._num += score or 0
self._i += 1
if self._step > 0 and self._i % self._step == 0:
self._history.append({"step": self._i, "score": self.score})
def to_pandas(self) -> pd.DataFrame:
return pd.DataFrame(self._history)
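# Usage sketch: the chain calls on_decision() once per prediction and
# on_feedback(score) once per scored response; `score` is the running mean.
#   tracker = MetricsTracker(step=1)
#   tracker.on_decision()
#   tracker.on_feedback(1.0)
#   tracker.to_pandas()  # DataFrame with "step" and "score" columns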

View File

@@ -0,0 +1,53 @@
from pathlib import Path
import shutil
import datetime
import vowpal_wabbit_next as vw
from typing import Union, Sequence
import os
import glob
import logging
logger = logging.getLogger(__name__)
class ModelRepository:
def __init__(
self,
folder: Union[str, os.PathLike],
with_history: bool = True,
reset: bool = False,
):
self.folder = Path(folder)
self.model_path = self.folder / "latest.vw"
self.with_history = with_history
if reset and self.has_history():
logger.warning(
"There is a non-empty model history which is recommended to be cleaned up"
)
if self.model_path.exists():
os.remove(self.model_path)
self.folder.mkdir(parents=True, exist_ok=True)
def get_tag(self) -> str:
return datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
def has_history(self) -> bool:
return len(glob.glob(str(self.folder / "model-????????-??????.vw"))) > 0
def save(self, workspace: vw.Workspace) -> None:
with open(self.model_path, "wb") as f:
logger.info(f"storing rl_chain model in: {self.model_path}")
f.write(workspace.serialize())
if self.with_history: # write history
shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
def load(self, commandline: Sequence[str]) -> vw.Workspace:
model_data = None
if self.model_path.exists():
with open(self.model_path, "rb") as f:
model_data = f.read()
if model_data:
logger.info(f"rl_chain model is loaded from: {self.model_path}")
return vw.Workspace(commandline, model_data=model_data)
return vw.Workspace(commandline)
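# Resulting folder layout sketch (with_history=True; timestamp illustrative):
#   latest.vw                  <- always the most recent model state
#   model-20230822-091823.vw   <- one timestamped checkpoint per save()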

View File

@@ -0,0 +1,284 @@
from __future__ import annotations
import langchain.chains.rl_chain.base as base
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from sentence_transformers import SentenceTransformer
from langchain.prompts import BasePromptTemplate
import logging
logger = logging.getLogger(__name__)
# sentinel object used to distinguish between a user not supplying anything and a user explicitly supplying None
SENTINEL = object()
class PickBestFeatureEmbedder(base.Embedder):
"""
Contextual Bandit Text Embedder class that embeds the based_on and to_select_from into a format that can be used by VW
Attributes:
model (Any, optional): The model to be used for feature embedding. Defaults to the bert-base-nli-mean-tokens SentenceTransformer.
"""
def __init__(self, model: Optional[Any] = None, *args, **kwargs):
super().__init__(*args, **kwargs)
if model is None:
model = SentenceTransformer("bert-base-nli-mean-tokens")
self.model = model
def format(self, event: PickBest.Event) -> str:
"""
Converts the based_on and to_select_from into a format that can be used by VW
"""
cost = None
if event.selected:
chosen_action = event.selected.index
cost = (
-1.0 * event.selected.score
if event.selected.score is not None
else None
)
prob = event.selected.probability
context_emb = base.embed(event.based_on, self.model) if event.based_on else None
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)
action_embs = (
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if not context_emb or not action_embs:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
)
example_string = ""
example_string += f"shared "
for context_item in context_emb:
for ns, based_on in context_item.items():
example_string += f"|{ns} {' '.join(based_on) if isinstance(based_on, list) else based_on} "
example_string += "\n"
for i, action in enumerate(action_embs):
if cost is not None and chosen_action == i:
example_string += f"{chosen_action}:{cost}:{prob} "
for ns, action_embedding in action.items():
example_string += f"|{ns} {' '.join(action_embedding) if isinstance(action_embedding, list) else action_embedding} "
example_string += "\n"
# Strip the last newline
return example_string[:-1]
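# Sketch of the produced VW text format, taken from the unit tests in this PR.
# For based_on {"context": "context"} and to_select_from {"action1": ["0", "1", "2"]}
# with no label the output is:
#   shared |context context
#   |action1 0
#   |action1 1
#   |action1 2
# When a selection with a score exists, the chosen action's line is prefixed
# with "<index>:<cost>:<probability>", e.g. "0:-0.0:1.0 |action1 0".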
class PickBest(base.RLChain):
"""
PickBest class that utilizes the Vowpal Wabbit (VW) model for personalization.
The Chain is initialized with a set of potential to_select_from. For each call to the Chain, a specific action will be chosen based on the based_on input.
This chosen action is then passed to the prompt that will be utilized in the subsequent call to the LLM (Language Model).
The flow of this chain is:
- Chain is initialized
- Chain is called with an input containing the based_on and the List of potential to_select_from
- Chain chooses an action based on the based_on
- Chain calls the LLM with the chosen action
- LLM returns a response
- If the selection_scorer is specified, the response is checked against the selection_scorer
- The internal model will be updated with the based_on, action, and reward of the response (how good or bad the response was)
- The response is returned
input dictionary expects:
- at least one variable wrapped in BasedOn which will be the based_on to use for personalization
- one variable of a list wrapped in ToSelectFrom which will be the list of to_select_from for the Vowpal Wabbit model to choose from.
This list can either be a List of str's or a List of Dict's.
- Actions provided as a list of strings e.g. to_select_from = ["action1", "action2", "action3"]
- If to_select_from are provided as a list of dictionaries, each action should be a dictionary where the keys are namespace names and the values are the corresponding action strings e.g. to_select_from = [{"namespace1": "action1", "namespace2": "action2"}, {"namespace1": "action3", "namespace2": "action4"}]
Extends:
RLChain
Attributes:
feature_embedder: (PickBestFeatureEmbedder, optional) The text embedder to use for embedding the based_on and the to_select_from. If not provided, a default embedder is used.
"""
class Selected(base.Selected):
index: Optional[int]
probability: Optional[float]
score: Optional[float]
def __init__(
self,
index: Optional[int] = None,
probability: Optional[float] = None,
score: Optional[float] = None,
):
self.index = index
self.probability = probability
self.score = score
class Event(base.Event):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBest.Selected] = None,
):
super().__init__(inputs=inputs, selected=selected)
self.to_select_from = to_select_from
self.based_on = based_on
def __init__(
self,
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
*args,
**kwargs,
):
vw_cmd = kwargs.get("vw_cmd", [])
if not vw_cmd:
vw_cmd = [
"--cb_explore_adf",
"--quiet",
"--interactions=::",
"--coin",
"--epsilon=0.2",
]
else:
if "--cb_explore_adf" not in vw_cmd:
raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf"
)
kwargs["vw_cmd"] = vw_cmd
if not feature_embedder:
feature_embedder = PickBestFeatureEmbedder()
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBest.Event:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)
if len(actions) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for the PickBest chain. Please provide only one variable containing a list to select from."
)
if not context:
raise ValueError(
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
)
event = PickBest.Event(inputs=inputs, to_select_from=actions, based_on=context)
return event
def _call_after_predict_before_llm(
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
) -> Tuple[Dict[str, Any], PickBest.Event]:
prob_sum = sum(prob for _, prob in prediction)
probabilities = [prob / prob_sum for _, prob in prediction]
## sample from the pmf
sampled_index = np.random.choice(len(prediction), p=probabilities)
sampled_ap = prediction[sampled_index]
sampled_action = sampled_ap[0]
sampled_prob = sampled_ap[1]
selected = PickBest.Selected(index=sampled_action, probability=sampled_prob)
event.selected = selected
# only one key, value pair in event.to_select_from
key, value = next(iter(event.to_select_from.items()))
next_chain_inputs = inputs.copy()
next_chain_inputs.update({key: value[event.selected.index]})
return next_chain_inputs, event
def _call_after_llm_before_scoring(
self, llm_response: str, event: PickBest.Event
) -> Tuple[Dict[str, Any], PickBest.Event]:
next_chain_inputs = event.inputs.copy()
# only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values()))
next_chain_inputs.update(
{
self.selected_based_on_input_key: str(event.based_on),
self.selected_input_key: value[event.selected.index],
}
)
return next_chain_inputs, event
def _call_after_scoring_before_learning(
self, event: PickBest.Event, score: Optional[float]
) -> Event:
event.selected.score = score
return event
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""
When chain.run() is called with the given inputs, this function is called. It is responsible for calling the VW model to choose an action from the ToSelectFrom list based on the BasedOn input, and then calling the LLM (Language Model) with the chosen action to generate a response.
Attributes:
inputs: (Dict, required) The inputs to the chain. The inputs must contain input variables that are wrapped in BasedOn and ToSelectFrom. BasedOn provides the context that will be used for selecting a ToSelectFrom action, which is then passed to the LLM prompt.
run_manager: (CallbackManagerForChainRun, optional) The callback manager to use for this run. If not provided, a default callback manager is used.
Returns:
A dictionary containing:
- `response`: The response generated by the LLM (Language Model).
- `selection_metadata`: An Event object containing all the information needed to learn the reward for the chosen action at a later point. If an automatic selection_scorer is not provided, then this object can be used at a later point with the `update_with_delayed_score()` function to learn the delayed reward and update the Vowpal Wabbit model.
- the `score` in the `selection_metadata` object is set to None if an automatic selection_scorer is not provided or if the selection_scorer failed (e.g. LLM timeout or LLM failed to rank correctly).
"""
return super()._call(run_manager=run_manager, inputs=inputs)
@property
def _chain_type(self) -> str:
return "rl_chain_pick_best"
@classmethod
def from_chain(
cls,
llm_chain: Chain,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
**kwargs: Any,
):
if selection_scorer is SENTINEL:
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
return PickBest(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,
**kwargs,
)
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
**kwargs: Any,
):
llm_chain = LLMChain(llm=llm, prompt=prompt)
return PickBest.from_chain(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,
**kwargs,
)

View File

@@ -0,0 +1,18 @@
from typing import Union, Optional
from pathlib import Path
from os import PathLike
class VwLogger:
def __init__(self, path: Optional[Union[str, PathLike]]):
self.path = Path(path) if path else None
if self.path:
self.path.parent.mkdir(parents=True, exist_ok=True)
def log(self, vw_ex: str):
if self.path:
with open(self.path, "a") as f:
f.write(f"{vw_ex}\n\n")
def logging_enabled(self):
return bool(self.path)

View File

@@ -125,6 +125,7 @@ newspaper3k = {version = "^0.2.8", optional = true}
amazon-textract-caller = {version = "<2", optional = true}
xata = {version = "^1.0.0a7", optional = true}
xmltodict = {version = "^0.13.0", optional = true}
vowpal-wabbit-next = "0.6.0"
[tool.poetry.group.test.dependencies]

View File

@@ -0,0 +1,286 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
import langchain.chains.rl_chain.base as rl_chain
from test_utils import MockEncoder
import pytest
from langchain.prompts.prompt import PromptTemplate
from langchain.chat_models import FakeListChatModel
encoded_text = "[ e n c o d e d ] "
def setup():
_PROMPT_TEMPLATE = """This is a dummy prompt that will be ignored by the fake llm"""
PROMPT = PromptTemplate(input_variables=[], template=_PROMPT_TEMPLATE)
llm = FakeListChatModel(responses=["hey"])
return llm, PROMPT
def test_multiple_ToSelectFrom_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
another_action=rl_chain.ToSelectFrom(actions),
)
def test_missing_basedOn_from_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
chain.run(action=rl_chain.ToSelectFrom(actions))
def test_ToSelectFrom_not_a_list_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
actions = {"actions": ["0", "1", "2"]}
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
def test_update_with_delayed_score_with_auto_validator_throws():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0
with pytest.raises(RuntimeError):
chain.update_with_delayed_score(event=selection_metadata, score=100)
def test_update_with_delayed_score_force():
llm, PROMPT = setup()
# this LLM returns a number so that the auto validator will return that
auto_val_llm = FakeListChatModel(responses=["3"])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 3.0
chain.update_with_delayed_score(
event=selection_metadata, score=100, force_score=True
)
assert selection_metadata.selected.score == 100.0
def test_update_with_delayed_score():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score is None
chain.update_with_delayed_score(event=selection_metadata, score=100)
assert selection_metadata.selected.score == 100.0
def test_user_defined_scorer():
llm, PROMPT = setup()
class CustomSelectionScorer(rl_chain.SelectionScorer):
def score_response(self, inputs, llm_response: str) -> float:
score = 200
return score
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
)
actions = ["0", "1", "2"]
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(actions),
)
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 200.0
def test_default_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
expected = f"""shared |User {ctx_str_1 + " " + encoded_ctx_str_1} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {str3 + " " + encoded_str3} """
actions = [str1, str2, str3]
response = chain.run(
User=rl_chain.BasedOn(ctx_str_1),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"]
vw_str = feature_embedder.format(selection_metadata)
assert vw_str == expected
def test_default_embeddings_off():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
)
str1 = "0"
str2 = "1"
str3 = "2"
ctx_str_1 = "context1"
expected = f"""shared |User {ctx_str_1} \n|action {str1} \n|action {str2} \n|action {str3} """
actions = [str1, str2, str3]
response = chain.run(
User=rl_chain.BasedOn(ctx_str_1),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"]
vw_str = feature_embedder.format(selection_metadata)
assert vw_str == expected
def test_default_embeddings_mixed_w_explicit_user_embeddings():
llm, PROMPT = setup()
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
)
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
expected = f"""shared |User {encoded_ctx_str_1} |User2 {ctx_str_2 + " " + encoded_ctx_str_2} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {encoded_str3} """
actions = [str1, str2, rl_chain.Embed(str3)]
response = chain.run(
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
User2=rl_chain.BasedOn(ctx_str_2),
action=rl_chain.ToSelectFrom(actions),
)
selection_metadata = response["selection_metadata"]
vw_str = feature_embedder.format(selection_metadata)
assert vw_str == expected
def test_default_no_scorer_specified():
_, PROMPT = setup()
chain_llm = FakeListChatModel(responses=["100"])
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# chain llm used for both basic prompt and for scoring
assert response["response"] == "100"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 100.0
def test_explicitly_no_scorer():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(
llm=llm, prompt=PROMPT, selection_scorer=None
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# no scorer was provided, so the score is not set
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score is None
def test_auto_scorer_with_user_defined_llm():
llm, PROMPT = setup()
scorer_llm = FakeListChatModel(responses=["300"])
chain = pick_best_chain.PickBest.from_llm(
llm=llm,
prompt=PROMPT,
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
)
response = chain.run(
User=rl_chain.BasedOn("Context"),
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
)
# the user-defined scorer llm is used for scoring
assert response["response"] == "hey"
selection_metadata = response["selection_metadata"]
assert selection_metadata.selected.score == 300.0
def test_calling_chain_w_reserved_inputs_throws():
llm, PROMPT = setup()
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
rl_chain_selected_based_on=rl_chain.ToSelectFrom(
["0", "1", "2"]
),
)
with pytest.raises(ValueError):
chain.run(
User=rl_chain.BasedOn("Context"),
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
)

View File

@@ -0,0 +1,331 @@
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
import langchain.chains.rl_chain.base as rl_chain
from test_utils import MockEncoder
import pytest
encoded_text = "[ e n c o d e d ] "
def test_pickbest_textembedder_missing_context_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_action = {"action": ["0", "1", "2"]}
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_action, based_on={}
)
with pytest.raises(ValueError):
feature_embedder.format(event)
def test_pickbest_textembedder_missing_actions_throws():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from={}, based_on={"context": "context"}
)
with pytest.raises(ValueError):
feature_embedder.format(event)
def test_pickbest_textembedder_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on={"context": "context"}
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_w_label_no_score_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = """shared |context context \n|action1 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
selected=selected,
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": ["0", "1", "2"]}
expected = (
"""shared |context context \n0:-0.0:1.0 |action1 0 \n|action1 1 \n|action1 2 """
)
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={},
to_select_from=named_actions,
based_on={"context": "context"},
selected=selected,
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_w_full_label_w_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
named_actions = {"action1": rl_chain.Embed([str1, str2, str3])}
context = {"context": rl_chain.Embed(ctx_str_1)}
expected = f"""shared |context {encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_w_full_label_w_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
named_actions = {"action1": rl_chain.EmbedAndKeep([str1, str2, str3])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str_1)}
expected = f"""shared |context {ctx_str_1 + " " + encoded_ctx_str_1} \n0:-0.0:1.0 |action1 {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_more_namespaces_no_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_more_namespaces_w_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n|a 0 |b 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
context = {"context1": "context1", "context2": "context2"}
expected = """shared |context1 context1 |context2 context2 \n0:-0.0:1.0 |a 0 |b 0 \n|action1 1 \n|action1 2 """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])
}
context = {
"context1": rl_chain.Embed(ctx_str_1),
"context2": rl_chain.Embed(ctx_str_2),
}
expected = f"""shared |context1 {encoded_ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {encoded_str1} |b {encoded_str1} \n|action1 {encoded_str2} \n|action1 {encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": rl_chain.EmbedAndKeep(
[{"a": str1, "b": str1}, str2, str3]
)
}
context = {
"context1": rl_chain.EmbedAndKeep(ctx_str_1),
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
}
expected = f"""shared |context1 {ctx_str_1 + " " + encoded_ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1 + " " + encoded_str1} |b {str1 + " " + encoded_str1} \n|action1 {str2 + " " + encoded_str2} \n|action1 {str3 + " " + encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": [
{"a": str1, "b": rl_chain.Embed(str1)},
str2,
rl_chain.Embed(str3),
]
}
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_embed_and_keep():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "0"
str2 = "1"
str3 = "2"
encoded_str1 = encoded_text + " ".join(char for char in str1)
encoded_str2 = encoded_text + " ".join(char for char in str2)
encoded_str3 = encoded_text + " ".join(char for char in str3)
ctx_str_1 = "context1"
ctx_str_2 = "context2"
encoded_ctx_str_1 = encoded_text + " ".join(char for char in ctx_str_1)
encoded_ctx_str_2 = encoded_text + " ".join(char for char in ctx_str_2)
named_actions = {
"action1": [
{"a": str1, "b": rl_chain.EmbedAndKeep(str1)},
str2,
rl_chain.EmbedAndKeep(str3),
]
}
context = {
"context1": ctx_str_1,
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
}
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """
selected = pick_best_chain.PickBest.Selected(index=0, probability=1.0, score=0.0)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context, selected=selected
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected
def test_raw_features_underscored():
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
str1 = "this is a long string"
str1_underscored = str1.replace(" ", "_")
encoded_str1 = encoded_text + " ".join(char for char in str1)
ctx_str = "this is a long context"
ctx_str_underscored = ctx_str.replace(" ", "_")
encoded_ctx_str = encoded_text + " ".join(char for char in ctx_str)
# No embeddings
named_actions = {"action": [str1]}
context = {"context": ctx_str}
expected_no_embed = (
f"""shared |context {ctx_str_underscored} \n|action {str1_underscored} """
)
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_no_embed
# Just embeddings
named_actions = {"action": rl_chain.Embed([str1])}
context = {"context": rl_chain.Embed(ctx_str)}
expected_embed = f"""shared |context {encoded_ctx_str} \n|action {encoded_str1} """
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_embed
# Embeddings and raw features
named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected_embed_and_keep = f"""shared |context {ctx_str_underscored + " " + encoded_ctx_str} \n|action {str1_underscored + " " + encoded_str1} """
event = pick_best_chain.PickBest.Event(
inputs={}, to_select_from=named_actions, based_on=context
)
vw_ex_str = feature_embedder.format(event)
assert vw_ex_str == expected_embed_and_keep

View File

@@ -0,0 +1,406 @@
import langchain.chains.rl_chain.base as base
from test_utils import MockEncoder
import pytest
encoded_text = "[ e n c o d e d ] "
def test_simple_context_str_no_emb():
expected = [{"a_namespace": "test"}]
assert base.embed("test", MockEncoder(), "a_namespace") == expected
def test_simple_context_str_w_emb():
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"a_namespace": encoded_text + encoded_str1}]
assert base.embed(base.Embed(str1), MockEncoder(), "a_namespace") == expected
expected_embed_and_keep = [
{"a_namespace": str1 + " " + encoded_text + encoded_str1}
]
assert (
base.embed(base.EmbedAndKeep(str1), MockEncoder(), "a_namespace")
== expected_embed_and_keep
)
def test_simple_context_str_w_nested_emb():
# nested embeddings, innermost wins
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"a_namespace": encoded_text + encoded_str1}]
assert (
base.embed(base.EmbedAndKeep(base.Embed(str1)), MockEncoder(), "a_namespace")
== expected
)
expected2 = [{"a_namespace": str1 + " " + encoded_text + encoded_str1}]
assert (
base.embed(base.Embed(base.EmbedAndKeep(str1)), MockEncoder(), "a_namespace")
== expected2
)
def test_context_w_namespace_no_emb():
expected = [{"test_namespace": "test"}]
assert base.embed({"test_namespace": "test"}, MockEncoder()) == expected
def test_context_w_namespace_w_emb():
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}]
assert base.embed({"test_namespace": base.Embed(str1)}, MockEncoder()) == expected
expected_embed_and_keep = [
{"test_namespace": str1 + " " + encoded_text + encoded_str1}
]
assert (
base.embed({"test_namespace": base.EmbedAndKeep(str1)}, MockEncoder())
== expected_embed_and_keep
)
def test_context_w_namespace_w_emb2():
str1 = "test"
encoded_str1 = " ".join(char for char in str1)
expected = [{"test_namespace": encoded_text + encoded_str1}]
assert base.embed(base.Embed({"test_namespace": str1}), MockEncoder()) == expected
expected_embed_and_keep = [
{"test_namespace": str1 + " " + encoded_text + encoded_str1}
]
assert (
base.embed(base.EmbedAndKeep({"test_namespace": str1}), MockEncoder())
== expected_embed_and_keep
)
def test_context_w_namespace_w_some_emb():
str1 = "test1"
str2 = "test2"
encoded_str2 = " ".join(char for char in str2)
expected = [
{"test_namespace": str1, "test_namespace2": encoded_text + encoded_str2}
]
assert (
base.embed(
{"test_namespace": str1, "test_namespace2": base.Embed(str2)}, MockEncoder()
)
== expected
)
expected_embed_and_keep = [
{
"test_namespace": str1,
"test_namespace2": str2 + " " + encoded_text + encoded_str2,
}
]
assert (
base.embed(
{"test_namespace": str1, "test_namespace2": base.EmbedAndKeep(str2)},
MockEncoder(),
)
== expected_embed_and_keep
)
def test_simple_action_strlist_no_emb():
str1 = "test1"
str2 = "test2"
str3 = "test3"
expected = [{"a_namespace": str1}, {"a_namespace": str2}, {"a_namespace": str3}]
assert base.embed([str1, str2, str3], MockEncoder(), "a_namespace") == expected
def test_simple_action_strlist_w_emb():
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = " ".join(char for char in str1)
encoded_str2 = " ".join(char for char in str2)
encoded_str3 = " ".join(char for char in str3)
expected = [
{"a_namespace": encoded_text + encoded_str1},
{"a_namespace": encoded_text + encoded_str2},
{"a_namespace": encoded_text + encoded_str3},
]
assert (
base.embed(base.Embed([str1, str2, str3]), MockEncoder(), "a_namespace")
== expected
)
expected_embed_and_keep = [
{"a_namespace": str1 + " " + encoded_text + encoded_str1},
{"a_namespace": str2 + " " + encoded_text + encoded_str2},
{"a_namespace": str3 + " " + encoded_text + encoded_str3},
]
assert (
base.embed(base.EmbedAndKeep([str1, str2, str3]), MockEncoder(), "a_namespace")
== expected_embed_and_keep
)
def test_simple_action_strlist_w_some_emb():
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str2 = " ".join(char for char in str2)
encoded_str3 = " ".join(char for char in str3)
expected = [
{"a_namespace": str1},
{"a_namespace": encoded_text + encoded_str2},
{"a_namespace": encoded_text + encoded_str3},
]
assert (
base.embed(
[str1, base.Embed(str2), base.Embed(str3)], MockEncoder(), "a_namespace"
)
== expected
)
expected_embed_and_keep = [
{"a_namespace": str1},
{"a_namespace": str2 + " " + encoded_text + encoded_str2},
{"a_namespace": str3 + " " + encoded_text + encoded_str3},
]
assert (
base.embed(
[str1, base.EmbedAndKeep(str2), base.EmbedAndKeep(str3)],
MockEncoder(),
"a_namespace",
)
== expected_embed_and_keep
)
def test_action_w_namespace_no_emb():
str1 = "test1"
str2 = "test2"
str3 = "test3"
expected = [
{"test_namespace": str1},
{"test_namespace": str2},
{"test_namespace": str3},
]
assert (
base.embed(
[
{"test_namespace": str1},
{"test_namespace": str2},
{"test_namespace": str3},
],
MockEncoder(),
)
== expected
)
def test_action_w_namespace_w_emb():
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = " ".join(char for char in str1)
encoded_str2 = " ".join(char for char in str2)
encoded_str3 = " ".join(char for char in str3)
expected = [
{"test_namespace": encoded_text + encoded_str1},
{"test_namespace": encoded_text + encoded_str2},
{"test_namespace": encoded_text + encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": base.Embed(str1)},
{"test_namespace": base.Embed(str2)},
{"test_namespace": base.Embed(str3)},
],
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{"test_namespace": str1 + " " + encoded_text + encoded_str1},
{"test_namespace": str2 + " " + encoded_text + encoded_str2},
{"test_namespace": str3 + " " + encoded_text + encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": base.EmbedAndKeep(str1)},
{"test_namespace": base.EmbedAndKeep(str2)},
{"test_namespace": base.EmbedAndKeep(str3)},
],
MockEncoder(),
)
== expected_embed_and_keep
)
def test_action_w_namespace_w_emb2():
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = " ".join(char for char in str1)
encoded_str2 = " ".join(char for char in str2)
encoded_str3 = " ".join(char for char in str3)
expected = [
{"test_namespace1": encoded_text + encoded_str1},
{"test_namespace2": encoded_text + encoded_str2},
{"test_namespace3": encoded_text + encoded_str3},
]
assert (
base.embed(
base.Embed(
[
{"test_namespace1": str1},
{"test_namespace2": str2},
{"test_namespace3": str3},
]
),
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{"test_namespace1": str1 + " " + encoded_text + encoded_str1},
{"test_namespace2": str2 + " " + encoded_text + encoded_str2},
{"test_namespace3": str3 + " " + encoded_text + encoded_str3},
]
assert (
base.embed(
base.EmbedAndKeep(
[
{"test_namespace1": str1},
{"test_namespace2": str2},
{"test_namespace3": str3},
]
),
MockEncoder(),
)
== expected_embed_and_keep
)
def test_action_w_namespace_w_some_emb():
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str2 = " ".join(char for char in str2)
encoded_str3 = " ".join(char for char in str3)
expected = [
{"test_namespace": str1},
{"test_namespace": encoded_text + encoded_str2},
{"test_namespace": encoded_text + encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": str1},
{"test_namespace": base.Embed(str2)},
{"test_namespace": base.Embed(str3)},
],
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{"test_namespace": str1},
{"test_namespace": str2 + " " + encoded_text + encoded_str2},
{"test_namespace": str3 + " " + encoded_text + encoded_str3},
]
assert (
base.embed(
[
{"test_namespace": str1},
{"test_namespace": base.EmbedAndKeep(str2)},
{"test_namespace": base.EmbedAndKeep(str3)},
],
MockEncoder(),
)
== expected_embed_and_keep
)
def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict():
str1 = "test1"
str2 = "test2"
str3 = "test3"
encoded_str1 = " ".join(char for char in str1)
encoded_str2 = " ".join(char for char in str2)
encoded_str3 = " ".join(char for char in str3)
expected = [
{"test_namespace": encoded_text + encoded_str1, "test_namespace2": str1},
{"test_namespace": encoded_text + encoded_str2, "test_namespace2": str2},
{"test_namespace": encoded_text + encoded_str3, "test_namespace2": str3},
]
assert (
base.embed(
[
{"test_namespace": base.Embed(str1), "test_namespace2": str1},
{"test_namespace": base.Embed(str2), "test_namespace2": str2},
{"test_namespace": base.Embed(str3), "test_namespace2": str3},
],
MockEncoder(),
)
== expected
)
expected_embed_and_keep = [
{
"test_namespace": str1 + " " + encoded_text + encoded_str1,
"test_namespace2": str1,
},
{
"test_namespace": str2 + " " + encoded_text + encoded_str2,
"test_namespace2": str2,
},
{
"test_namespace": str3 + " " + encoded_text + encoded_str3,
"test_namespace2": str3,
},
]
assert (
base.embed(
[
{"test_namespace": base.EmbedAndKeep(str1), "test_namespace2": str1},
{"test_namespace": base.EmbedAndKeep(str2), "test_namespace2": str2},
{"test_namespace": base.EmbedAndKeep(str3), "test_namespace2": str3},
],
MockEncoder(),
)
== expected_embed_and_keep
)
def test_one_namespace_w_list_of_features_no_emb():
str1 = "test1"
str2 = "test2"
expected = [{"test_namespace": [str1, str2]}]
assert base.embed({"test_namespace": [str1, str2]}, MockEncoder()) == expected
def test_one_namespace_w_list_of_features_w_some_emb():
str1 = "test1"
str2 = "test2"
encoded_str2 = " ".join(char for char in str2)
expected = [{"test_namespace": [str1, encoded_text + encoded_str2]}]
assert (
base.embed({"test_namespace": [str1, base.Embed(str2)]}, MockEncoder())
== expected
)
def test_nested_list_features_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": [[1, 2], [3, 4]]}, MockEncoder())
def test_dict_in_list_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": [{"a": 1}, {"b": 2}]}, MockEncoder())
def test_nested_dict_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": {"a": {"b": 1}}}, MockEncoder())
def test_list_of_tuples_throws():
with pytest.raises(ValueError):
base.embed({"test_namespace": [("a", 1), ("b", 2)]}, MockEncoder())

View File

@@ -0,0 +1,3 @@
class MockEncoder:
def encode(self, to_encode):
return "[encoded]" + to_encode