This commit is contained in:
olgavrou 2023-09-04 16:36:29 -04:00
parent e10980d445
commit ae5edefdcd
2 changed files with 16 additions and 13 deletions

View File

@ -316,7 +316,7 @@ class RLChain(Chain, Generic[TEvent]):
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
- auto_embed (bool): Determines if embedding should be automatic. Default is False.
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
- metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics, can be set to None.
Initialization Attributes:
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
@ -325,7 +325,8 @@ class RLChain(Chain, Generic[TEvent]):
- vw_cmd (List[str], optional): Command line arguments for the VW model.
- policy (Type[VwPolicy]): Policy used by the chain.
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
- metrics_step (int): Step for the metrics tracker. Default is -1.
- metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked, otherwise rolling window metrics will be tracked.
- metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked.
Notes:
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.

View File

@ -137,7 +137,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
context_matrix = np.stack([v for k, v in context_embeddings.items()])
dot_product_matrix = np.dot(context_matrix, action_matrix.T)
indexed_dot_product: Dict[Dict] = {}
indexed_dot_product: Dict = {}
for i, context_key in enumerate(context_embeddings.keys()):
indexed_dot_product[context_key] = {}
@ -258,6 +258,18 @@ class PickBest(base.RLChain[PickBestEvent]):
):
auto_embed = kwargs.get("auto_embed", False)
feature_embedder = kwargs.get("feature_embedder", None)
if feature_embedder:
if "auto_embed" in kwargs:
logger.warning(
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
)
# turning auto_embed off for cli setting below
auto_embed = False
else:
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
kwargs["feature_embedder"] = feature_embedder
vw_cmd = kwargs.get("vw_cmd", [])
if vw_cmd:
if "--cb_explore_adf" not in vw_cmd:
@ -281,16 +293,6 @@ class PickBest(base.RLChain[PickBestEvent]):
kwargs["vw_cmd"] = vw_cmd
feature_embedder = kwargs.get("feature_embedder", None)
if feature_embedder:
if "auto_embed" in kwargs:
logger.warning(
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
)
else:
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
kwargs["feature_embedder"] = feature_embedder
super().__init__(*args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: