no errors in pick best chain

pull/10242/head
olgavrou 1 year ago
parent 6a1102d4c0
commit dd6fff1c62

@ -13,7 +13,7 @@ from langchain.chains.rl_chain.base import (
from langchain.chains.rl_chain.pick_best_chain import PickBest
def configure_logger():
def configure_logger() -> None:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import langchain.chains.rl_chain.base as base
from langchain.base_language import BaseLanguageModel
@ -145,7 +145,6 @@ class PickBest(base.RLChain[PickBest.Event]):
def __init__(
self,
feature_embedder: Optional[PickBestFeatureEmbedder] = None,
*args: Any,
**kwargs: Any,
):
@ -163,12 +162,14 @@ class PickBest(base.RLChain[PickBest.Event]):
raise ValueError(
"If vw_cmd is specified, it must include --cb_explore_adf"
)
kwargs["vw_cmd"] = vw_cmd
feature_embedder = kwargs.get("feature_embedder", None)
if not feature_embedder:
feature_embedder = PickBestFeatureEmbedder()
kwargs["feature_embedder"] = feature_embedder
super().__init__(feature_embedder=feature_embedder, *args, **kwargs)
super().__init__(*args, **kwargs)
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
@ -223,10 +224,15 @@ class PickBest(base.RLChain[PickBest.Event]):
next_chain_inputs = event.inputs.copy()
# only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values()))
v = (
value[event.selected.index]
if event.selected
else event.to_select_from.values()
)
next_chain_inputs.update(
{
self.selected_based_on_input_key: str(event.based_on),
self.selected_input_key: value[event.selected.index],
self.selected_input_key: v,
}
)
return next_chain_inputs, event
@ -234,7 +240,8 @@ class PickBest(base.RLChain[PickBest.Event]):
def _call_after_scoring_before_learning(
self, event: Event, score: Optional[float]
) -> Event:
event.selected.score = score
if event.selected:
event.selected.score = score
return event
def _call(
@ -248,33 +255,19 @@ class PickBest(base.RLChain[PickBest.Event]):
def _chain_type(self) -> str:
return "rl_chain_pick_best"
@classmethod
def from_chain(
cls,
llm_chain: Chain,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
**kwargs: Any,
):
if selection_scorer is SENTINEL:
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
return PickBest(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,
**kwargs,
)
@classmethod
def from_llm(
cls,
cls: Type[PickBest],
llm: BaseLanguageModel,
prompt: BasePromptTemplate,
selection_scorer=SENTINEL,
selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
**kwargs: Any,
):
) -> PickBest:
llm_chain = LLMChain(llm=llm, prompt=prompt)
return PickBest.from_chain(
if selection_scorer is SENTINEL:
selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
return PickBest(
llm_chain=llm_chain,
prompt=prompt,
selection_scorer=selection_scorer,

Loading…
Cancel
Save