mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
cleanup
This commit is contained in:
parent
e10980d445
commit
ae5edefdcd
@ -316,7 +316,7 @@ class RLChain(Chain, Generic[TEvent]):
|
||||
- selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
|
||||
- policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
|
||||
- auto_embed (bool): Determines if embedding should be automatic. Default is False.
|
||||
- metrics (Optional[MetricsTracker]): Tracker for metrics, can be set to None.
|
||||
- metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics, can be set to None.
|
||||
|
||||
Initialization Attributes:
|
||||
- feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
|
||||
@ -325,7 +325,8 @@ class RLChain(Chain, Generic[TEvent]):
|
||||
- vw_cmd (List[str], optional): Command line arguments for the VW model.
|
||||
- policy (Type[VwPolicy]): Policy used by the chain.
|
||||
- vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
|
||||
- metrics_step (int): Step for the metrics tracker. Default is -1.
|
||||
- metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked, otherwise rolling window metrics will be tracked.
|
||||
- metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked.
|
||||
|
||||
Notes:
|
||||
The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
|
||||
|
@ -137,7 +137,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
||||
context_matrix = np.stack([v for k, v in context_embeddings.items()])
|
||||
dot_product_matrix = np.dot(context_matrix, action_matrix.T)
|
||||
|
||||
indexed_dot_product: Dict[Dict] = {}
|
||||
indexed_dot_product: Dict = {}
|
||||
|
||||
for i, context_key in enumerate(context_embeddings.keys()):
|
||||
indexed_dot_product[context_key] = {}
|
||||
@ -258,6 +258,18 @@ class PickBest(base.RLChain[PickBestEvent]):
|
||||
):
|
||||
auto_embed = kwargs.get("auto_embed", False)
|
||||
|
||||
feature_embedder = kwargs.get("feature_embedder", None)
|
||||
if feature_embedder:
|
||||
if "auto_embed" in kwargs:
|
||||
logger.warning(
|
||||
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
|
||||
)
|
||||
# turning auto_embed off for cli setting below
|
||||
auto_embed = False
|
||||
else:
|
||||
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
|
||||
kwargs["feature_embedder"] = feature_embedder
|
||||
|
||||
vw_cmd = kwargs.get("vw_cmd", [])
|
||||
if vw_cmd:
|
||||
if "--cb_explore_adf" not in vw_cmd:
|
||||
@ -281,16 +293,6 @@ class PickBest(base.RLChain[PickBestEvent]):
|
||||
|
||||
kwargs["vw_cmd"] = vw_cmd
|
||||
|
||||
feature_embedder = kwargs.get("feature_embedder", None)
|
||||
if feature_embedder:
|
||||
if "auto_embed" in kwargs:
|
||||
logger.warning(
|
||||
"auto_embed will take no effect when explicit feature_embedder is provided" # noqa E501
|
||||
)
|
||||
else:
|
||||
feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
|
||||
kwargs["feature_embedder"] = feature_embedder
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
|
||||
|
Loading…
Reference in New Issue
Block a user