|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|
|
|
|
|
|
|
|
|
import langchain.chains.rl_chain.base as base
|
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
|
|
@ -145,7 +145,6 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
|
|
|
|
|
*args: Any,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
):
|
|
|
|
@ -163,12 +162,14 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"If vw_cmd is specified, it must include --cb_explore_adf"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
kwargs["vw_cmd"] = vw_cmd
|
|
|
|
|
|
|
|
|
|
feature_embedder = kwargs.get("feature_embedder", None)
|
|
|
|
|
if not feature_embedder:
|
|
|
|
|
feature_embedder = PickBestFeatureEmbedder()
|
|
|
|
|
kwargs["feature_embedder"] = feature_embedder
|
|
|
|
|
|
|
|
|
|
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
|
|
|
|
|
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
|
|
|
|
@ -223,10 +224,15 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|
|
|
|
next_chain_inputs = event.inputs.copy()
|
|
|
|
|
# only one key, value pair in event.to_select_from
|
|
|
|
|
value = next(iter(event.to_select_from.values()))
|
|
|
|
|
v = (
|
|
|
|
|
value[event.selected.index]
|
|
|
|
|
if event.selected
|
|
|
|
|
else event.to_select_from.values()
|
|
|
|
|
)
|
|
|
|
|
next_chain_inputs.update(
|
|
|
|
|
{
|
|
|
|
|
self.selected_based_on_input_key: str(event.based_on),
|
|
|
|
|
self.selected_input_key: value[event.selected.index],
|
|
|
|
|
self.selected_input_key: v,
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
return next_chain_inputs, event
|
|
|
|
@ -234,7 +240,8 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|
|
|
|
def _call_after_scoring_before_learning(
|
|
|
|
|
self, event: Event, score: Optional[float]
|
|
|
|
|
) -> Event:
|
|
|
|
|
event.selected.score = score
|
|
|
|
|
if event.selected:
|
|
|
|
|
event.selected.score = score
|
|
|
|
|
return event
|
|
|
|
|
|
|
|
|
|
def _call(
|
|
|
|
@ -248,33 +255,19 @@ class PickBest(base.RLChain[PickBest.Event]):
|
|
|
|
|
def _chain_type(self) -> str:
|
|
|
|
|
return "rl_chain_pick_best"
|
|
|
|
|
|
|
|
|
|
@classmethod
def from_chain(
    cls,
    llm_chain: Chain,
    prompt: BasePromptTemplate,
    selection_scorer: Any = SENTINEL,
    **kwargs: Any,
) -> PickBest:
    """Construct an instance of this class around an existing chain.

    Args:
        llm_chain: Chain forwarded to the constructor as ``llm_chain``. Its
            ``llm`` attribute is read when a default scorer has to be built.
        prompt: Prompt template forwarded to the constructor as ``prompt``.
        selection_scorer: Scorer forwarded to the constructor. When left at
            the ``SENTINEL`` default, a ``base.AutoSelectionScorer`` is
            created from ``llm_chain.llm``. Any other value, including
            ``None``, is passed through unchanged.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        A new instance of ``cls``.
    """
    # SENTINEL (rather than None) marks "argument omitted", so an explicit
    # None from the caller is not replaced by the default scorer.
    if selection_scorer is SENTINEL:
        selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)

    # Instantiate through ``cls`` so subclasses calling this alternate
    # constructor receive an instance of themselves, not of the base class.
    return cls(
        llm_chain=llm_chain,
        prompt=prompt,
        selection_scorer=selection_scorer,
        **kwargs,
    )
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_llm(
|
|
|
|
|
cls,
|
|
|
|
|
cls: Type[PickBest],
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
prompt: BasePromptTemplate,
|
|
|
|
|
selection_scorer=SENTINEL,
|
|
|
|
|
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
):
|
|
|
|
|
) -> PickBest:
|
|
|
|
|
llm_chain = LLMChain(llm=llm, prompt=prompt)
|
|
|
|
|
return PickBest.from_chain(
|
|
|
|
|
if selection_scorer is SENTINEL:
|
|
|
|
|
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
|
|
|
|
|
|
|
|
|
|
return PickBest(
|
|
|
|
|
llm_chain=llm_chain,
|
|
|
|
|
prompt=prompt,
|
|
|
|
|
selection_scorer=selection_scorer,
|
|
|
|
|